序言

定义模型类的时候,一般都需要继承nn.Module类。当我们后续对模型进行查看或者定位修改的时候很头疼它的api那么多,应该用哪个,怎么用,为什么……这篇博客就好好捋一捋nn.Module

(建议配合官网源码一起阅读。)

源文件第一行: from collections import OrderedDict, namedtuple,导入的这两个包很重要,尤其是OrderedDict,Module类的属性大多数都是这个类型的。OrderedDict就是一个有序的字典类型。

类定义

def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")

    self.training = True                             # 重要,设定本组件的模式:训练/测试
    self._parameters = OrderedDict()                 # 模型的参数
    self._buffers = OrderedDict()                    # 缓冲区,存储非参数但又属于模型state
    self._non_persistent_buffers_set = set()         # 配合上一属性使用
    self._backward_hooks = OrderedDict()		     # 反向传播的钩子 
    self._forward_hooks = OrderedDict()				 # 正向传播的钩子
    self._forward_pre_hooks = OrderedDict()			 # 正向传播预定义钩子
    self._state_dict_hooks = OrderedDict()			 # 模型状态钩子
    self._load_state_dict_pre_hooks = OrderedDict()  # 
    self._modules = OrderedDict() 					 # 重要!子模块属性

首先这里只有self.training是公开类属性,其他都加了_来约定这些实例变量都只在类内使用。如果想要使用其实也可以,毕竟_不是强约束。但对一些重要属性,Module类暴露了相应的接口给我们。

类定义中,设置了很多OrderedDict类型的实例变量,其中self._parameters, self._modules这两个属性用的最多,分别存储了本组件的参数和自组件。

这里还定义了4个不同类型的hook属性,它们是实现Module模块几个钩子函数的关键。

或许你可能会困惑这里定义的self._buffers属性有什么用。其实它是self._parameters的补充。 如果你模型有参数应该被保存在state_dict中,但是它用优化器来训练。那么就应该将其存在self._buffers这个属性中。Buffers不会通过model.parameters()返回,所以优化器不能更新它们。
另外persistentnon_persistent的区别在于后者不属于本组件的state_dict属性。

register buffer

def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
    """
    Args:
        name (string): name of the buffer. The buffer can be accessed
            from this module using the given name
        tensor (Tensor): buffer to be registered.
        persistent (bool): whether the buffer is part of this module's
            :attr:`state_dict`.
    """
    if persistent is False and isinstance(self, torch.jit.ScriptModule):
        raise RuntimeError("ScriptModule does not support non-persistent buffers")
        # torch.jit.ScriptModule 几乎用不上,不用管这个
    if '_buffers' not in self.__dict__:
        raise AttributeError(
            "cannot assign buffer before Module.__init__() call")
        # 需要先初始化,毕竟要先有`self._buffers`容器
        # 下面是一些buffer的key的类型检查,存在性检查
        # 注意buffer的val一定是torch.Tensor类型
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("buffer name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("buffer name can't contain \".\"")
    elif name == '':
        raise KeyError("buffer name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._buffers:
        raise KeyError("attribute '{}' already exists".format(name))
    elif tensor is not None and not isinstance(tensor, torch.Tensor):
        raise TypeError("cannot assign '{}' object to buffer '{}' "
                        "(torch Tensor or None required)"
                        .format(torch.typename(tensor), name))
    else:
        self._buffers[name] = tensor
        # 根据persistent属性存储
        if persistent:
            self._non_persistent_buffers_set.discard(name)
        else:
            self._non_persistent_buffers_set.add(name)

在预定义里解释过,需要随模型被保存,但不需要训练的参数(Tensor类型)可以使用本函数进行保存。
比如BatchNorm里的running_mean
但是实际中很少用到这个函数。(

register parameter

这个函数大部分设定和register_buffers相似,下面有部分不同之处的解读。

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
    """
    Args:
        name (string): name of the parameter. The parameter can be accessed
            from this module using the given name
        param (Parameter): parameter to be added to the module.
    """
    if '_parameters' not in self.__dict__:
        # 初始化检查,要先有`self._parameters`容器
        raise AttributeError(
            "cannot assign parameter before Module.__init__() call")
        # 下面是一些parameter的key的类型检查,存在性检查
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    elif '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    elif name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    elif hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))
    # parameter 类型检查
    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, Parameter):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or None required)"
                        .format(torch.typename(param), name))
    elif param.grad_fn:
        # 这里涉及到pytorch的autograd设计机制。最表象的规定就是nn.Parameter必须是一个
        # 叶节点。 如果这个tensor含有`.grad_fn`属性,那么它就不是叶节点,就不能赋值。
        raise ValueError(
            "Cannot assign non-leaf Tensor to parameter '{0}'. Model "
            "parameters must be created explicitly. To express '{0}' "
            "as a function of another Tensor, compute the value in "
            "the forward() method.".format(name))
    else:
        self._parameters[name] = param

又有一个问题,什么时候使用这个函数?(至少我之前没有用过)
上面已经分析过知道parameterbuffer的区别了,parameter是需要优化器训练学习的。 顺着BatchNorm里的running_mean找到nn.modules.batchnorm,这里提供了register_buffers, register_parameters 这两个函数的使用样例。

if self.affine: 
     self.weight = Parameter(torch.Tensor(num_features)) 
     self.bias = Parameter(torch.Tensor(num_features)) 
 else: 
     self.register_parameter('weight', None) 
     self.register_parameter('bias', None) 
 if self.track_running_stats: 
     self.register_buffer('running_mean', torch.zeros(num_features)) 
     self.register_buffer('running_var', torch.ones(num_features)) 
     self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 
 else: 
     self.register_parameter('running_mean', None) 
     self.register_parameter('running_var', None) 
     self.register_parameter('num_batches_tracked', None)

BatchNorm里的running_mean, runing_varmodel.train(), model.eval()两个状态下是不同的!所以可以通过这两个函数进行设置。

至于其他什么场景下会使用到这两个函数,就要看实际需求了。

add module

给当前组件添加子组件,套娃警告 >_>

def add_module(self, name: str, module: Optional['Module']) -> None:
    r"""Adds a child module to the current module.

    The module can be accessed as an attribute using the given name.

    Args:
        name (string): name of the child module. The child module can be
            accessed from this module using the given name
        module (Module): child module to be added to the module.
    """
    # 参数检查
    # 是否是Module
    if not isinstance(module, Module) and module is not None:
        raise TypeError("{} is not a Module subclass".format(
            torch.typename(module)))
    # 子组件名称
    elif not isinstance(name, torch._six.string_classes):
        raise TypeError("module name should be a string. Got {}".format(
            torch.typename(name)))
    # 子组件名称是否重复
    elif hasattr(self, name) and name not in self._modules:
        raise KeyError("attribute '{}' already exists".format(name))
    elif '.' in name:
        raise KeyError("module name can't contain \".\"")
    elif name == '':
        raise KeyError("module name can't be empty string \"\"")
    self._modules[name] = module

_apply

从名字就可以看出,这个函数是个建议只在类内使用的函数。本类中出现次数最多,使用最多的函数。因为 很多有了上面那个函数可以套娃,所以在对组件进行操作的时候,需要对子组件进行同样的操作。

def _apply(self, fn):
    # 循环嵌套
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        pass  # 内置函数  
    # 对自身的参数进行操作,包括参数本身,梯度
    for key, param in self._parameters.items():
        pass
    # 对buffer内的参数进行操作
    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)

    return self

apply

递归调用fn函数到每一个子模块。暴露给用户的接口,经常在初始化模型层参数的时候会用到。相当简单。

def apply(self: T, fn: Callable[['Module'], None]) -> T:
    """
    Args:
        fn (:class:`Module` -> None): function to be applied to each submodule
    Returns:
        Module: self
    """
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

cuda cpu type float…

省略号里还包括了double half bfloat16等函数,这些函数都比较简单。cuda(), cpu()的作用就是 把参数从一个device移到另一个device上。剩下的都是类型转换。都需要用到上面讲到的self._apply()函数。它们返回的都是模块本身。

to

这个函数和cuda(), cpu()类似,可以看到函数内定义了一个子函数convert(t),用于将state迁移 到其他设备上。最后做了一个_apply()函数,应用到所有子模块。

def to(self, *args, **kwargs):
    ...
    def convert(t):
        if convert_to_format is not None and t.dim() == 4:
            return t.to(device, dtype if t.is_floating_point() else None, non_blocking, memory_format=convert_to_format)
        return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

    return self._apply(convert)

hook

有关hook,我在之前的一篇博客里写过相关的介绍,可以移步去看看

魔法方法

名称前后各有两个下划线的方法,具有特殊意义的。这些方法过一遍就好,主要是类本身属性相关的一些内容。

setstate

更新self.__dict__.update(state)

getattr

获取属性,主要包括parameters,buffers,module中的属。首先在self.__dict__找到这几个字典, 然后通过name属性去查找。

def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
    if '_parameters' in self.__dict__:
        _parameters = self.__dict__['_parameters']
        if name in _parameters:
            return _parameters[name]
    if '_buffers' in self.__dict__:
        _buffers = self.__dict__['_buffers']
        if name in _buffers:
            return _buffers[name]
    if '_modules' in self.__dict__:
        modules = self.__dict__['_modules']
        if name in modules:
            return modules[name]
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
        type(self).__name__, name))

setattr

__getattr__相对应,设置的时候也分 parameters,buffers,module三类。以parameters为例, 1.先获取已有的params,2.对要设置的value进行类型检查,3.对要插入的name(key)进行存在性检测。 另外两个类似。

def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
        def remove_from(*dicts_or_sets):
            for d in dicts_or_sets:
                if name in d:
                    if isinstance(d, dict):
                        del d[name]
                    else:
                        d.discard(name)

    params = self.__dict__.get('_parameters')
    if isinstance(value, Parameter):
        if params is None:
            raise AttributeError(
                "cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)"
                            .format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, Module):
            pass
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                pass

delattr

参考上面两个函数。

def __delattr__(self, name):
    if name in self._parameters:
        del self._parameters[name]
    elif name in self._buffers:
        del self._buffers[name]
        self._non_persistent_buffers_set.discard(name)
    elif name in self._modules:
        del self._modules[name]
    else:
        object.__delattr__(self, name)

repr extra_repr

extra_repr__repr__的用户补充,即我们可以根据自己的需要自定义模组的repr属性。
__repr__的定义如下:

def __repr__(self):
    # We treat the extra repr like the sub-module, one item per line
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split('\n')
    child_lines = []
    for key, module in self._modules.items():
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append('(' + key + '): ' + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + '('
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += '\n  ' + '\n  '.join(lines) + '\n'

    main_str += ')'
    return main_str

先获取extra_repr表示,然后遍历子module,按照key,构造module名字,根据层级在前面加空格,然后拼接起来。

dir

python类内置的方法,用于返回字母顺序排列的属性列。这里module除了默认属性, 还要加上parameters,buffers,modules这些。 代码很简洁:

def __dir__(self):
    module_attrs = dir(self.__class__)
    attrs = list(self.__dict__.keys())
    parameters = list(self._parameters.keys())
    modules = list(self._modules.keys())
    buffers = list(self._buffers.keys())
    keys = module_attrs + attrs + parameters + modules + buffers

    # Eliminate attrs that are not legal Python variable names
    keys = [key for key in keys if not key[0].isdigit()]

    return sorted(keys)

state_dict load_state_dict

state_dict()返回一个包含整个module状态的字典,包括参数和persistent的buffers。键就是这些参数和buffer的名字。 通过函数描述,我觉得我也能写出这个函数来。)

def state_dict(self, destination=None, prefix='', keep_vars=False):
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
    self._save_to_state_dict(destination, prefix, keep_vars)
    for name, module in self._modules.items():
        if module is not None:
            module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in self._state_dict_hooks.values():
        hook_result = hook(self, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination

load_state_dict()也是常用的一个函数,但是要注意当前使用的模型的model.state_dict的key要和导入的这个匹配。 因为源码会对这个进行检查,另外strict参数可以限制是否严格限制状态匹配。

def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
                        strict: bool = True):
    # 检查匹配,缺失的,预期有但是没有的,错误信息。
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # 内置函数,使用_load_from_state_dict导入,隐去了处理的细节。
    def load(module, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    load(self)
    load = None  # break load->load reference cycle

    # strict的两种模式,如果设置了且存在error就会引发RuntimeError,否则没有。
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                            self.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)

_named_members

这是后续两大类方法的一个辅助函数,用于生成多种 name + members of modules.

def _named_members(self, get_members_fn, prefix='', recurse=True):
    # 由于named member 都是Orderdict类型,所以这里是set类型
    memo = set() 
    # 递归获取所有的module
    modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
    for module_prefix, module in modules:
        # 这里 `get_members_fn` 用函数作为变量,根据函数不同动态获取不同的member
        members = get_members_fn(module)
        # 遍历members对象,构造生成器
        for k, v in members:
            if v is None or v in memo:
                continue
            memo.add(v)
            name = module_prefix + ('.' if module_prefix else '') + k
            yield name, v

named_parameters parameters

named_parameters()返回遍历所有模型参数的迭代器,生成 名字+参数。

def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
    # 调用上面的_named_memebers()方法,注意这里传入的`get_members_fn`是个lambda函数
    # lambda接受一个module参数,返回他的_parameters内容(在module定义部分可以找到,这是个OrderDict类型变量)
    gen = self._named_members(
        lambda module: module._parameters.items(),
        prefix=prefix, recurse=recurse)
    for elem in gen:
        yield elem

parameters()named_parameters()的一个简化函数,因为它不需要name属性。 这个函数通常是用来将参数传递给优化器。所以代码也很简洁。

def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        for name, param in self.named_parameters(recurse=recurse):
            yield param

named_buffers buffers

同上,这里get_members_fnlambda module: module._buffers.item().

children modules

children(), named_children()modules(), named_modules()这两类方法比较类似。
named_children()返回生成直接子模型(即只有下一层)的迭代器,生成 名字 + 子模型。
注意这里并不会递归迭代。

def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module

named_modules()则是对上面的补充,它会返回模型所有的子模型,即它会递归遍历。

def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
    if memo is None:
        memo = set()
    if self not in memo:
        memo.add(self)
        yield prefix, self
        for name, module in self._modules.items():
            if module is None:
                continue
            submodule_prefix = prefix + ('.' if prefix else '') + name
            # 递归调用
            for m in module.named_modules(memo, submodule_prefix):
                yield m

合理利用这两个函数,我们可以定位到任意我们需要的层级。

train eval

这两个函数源码没有意思,因为只是递归设置了状态。需要再挖深一点才能看到有什么区别。

requiresgrad

用于设置autograd是否需要记录这个组件上参数的操作。
它会设置该组件所有参数的requires_grad参数。
当我们需要freeze某一个组件的时候十分有用,可以用于finetune或者只训练模型的一部分。

def requires_grad_(self: T, requires_grad: bool = True) -> T:
    for p in self.parameters():
        p.requires_grad_(requires_grad)
    return self

zero_grad

设置组件的所有参数的梯度为0,用于下一步更新之类的。正常训练过程中必须要用的一步骤。

def zero_grad(self, set_to_none: bool = False) -> None:
    if getattr(self, '_is_replica', False):
        warnings.warn(
            "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
            "The parameters are copied (in a differentiable manner) from the original module. "
            "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
            "If you need gradients in your forward method, consider using autograd.grad instead.")

    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

以上就是nn.Module类的一个大概了。