序言
定义模型类的时候,一般都需要继承nn.Module
类。当我们后续对模型进行查看或者定位修改的时候很头疼它的api那么多,应该用哪个,怎么用,为什么……这篇博客就好好捋一捋nn.Module
。
(建议配合官网源码一起阅读。)
源文件第一行: from collections import OrderedDict, namedtuple
,导入的这两个包很重要,尤其是OrderedDict
,Module类的属性大多数都是这个类型的。OrderedDict
就是一个有序的字典类型。
类定义
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True # 重要,设定本组件的模式:训练/测试
self._parameters = OrderedDict() # 模型的参数
self._buffers = OrderedDict() # 缓冲区,存储非参数但又属于模型state
self._non_persistent_buffers_set = set() # 配合上一属性使用
self._backward_hooks = OrderedDict() # 反向传播的钩子
self._forward_hooks = OrderedDict() # 正向传播的钩子
self._forward_pre_hooks = OrderedDict() # 正向传播预定义钩子
self._state_dict_hooks = OrderedDict() # 模型状态钩子
self._load_state_dict_pre_hooks = OrderedDict() #
self._modules = OrderedDict() # 重要!子模块属性
首先这里只有self.training
是公开类属性,其他都加了_
来约定这些实例变量都只在类内使用。如果想要使用其实也可以,毕竟_
不是强约束。但对一些重要属性,Module
类暴露了相应的接口给我们。
类定义中,设置了很多OrderedDict
类型的实例变量,其中self._parameters, self._modules
这两个属性用的最多,分别存储了本组件的参数和自组件。
这里还定义了4个不同类型的hook
属性,它们是实现Module
模块几个钩子函数的关键。
或许你可能会困惑这里定义的self._buffers
属性有什么用。其实它是self._parameters
的补充。
如果你模型有参数应该被保存在state_dict
中,但是它用优化器来训练。那么就应该将其存在self._buffers
这个属性中。Buffers
不会通过model.parameters()
返回,所以优化器不能更新它们。
另外persistent
和non_persistent
的区别在于后者不属于本组件的state_dict
属性。
register buffer
def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
"""
Args:
name (string): name of the buffer. The buffer can be accessed
from this module using the given name
tensor (Tensor): buffer to be registered.
persistent (bool): whether the buffer is part of this module's
:attr:`state_dict`.
"""
if persistent is False and isinstance(self, torch.jit.ScriptModule):
raise RuntimeError("ScriptModule does not support non-persistent buffers")
# torch.jit.ScriptModule 几乎用不上,不用管这个
if '_buffers' not in self.__dict__:
raise AttributeError(
"cannot assign buffer before Module.__init__() call")
# 需要先初始化,毕竟要先有`self._buffers`容器
# 下面是一些buffer的key的类型检查,存在性检查
# 注意buffer的val一定是torch.Tensor类型
elif not isinstance(name, torch._six.string_classes):
raise TypeError("buffer name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
raise KeyError("buffer name can't contain \".\"")
elif name == '':
raise KeyError("buffer name can't be empty string \"\"")
elif hasattr(self, name) and name not in self._buffers:
raise KeyError("attribute '{}' already exists".format(name))
elif tensor is not None and not isinstance(tensor, torch.Tensor):
raise TypeError("cannot assign '{}' object to buffer '{}' "
"(torch Tensor or None required)"
.format(torch.typename(tensor), name))
else:
self._buffers[name] = tensor
# 根据persistent属性存储
if persistent:
self._non_persistent_buffers_set.discard(name)
else:
self._non_persistent_buffers_set.add(name)
在预定义里解释过,需要随模型被保存,但不需要训练的参数(Tensor类型)可以使用本函数进行保存。
比如BatchNorm
里的running_mean
。
但是实际中很少用到这个函数。(
register parameter
这个函数大部分设定和register_buffers
相似,下面有部分不同之处的解读。
def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
"""
Args:
name (string): name of the parameter. The parameter can be accessed
from this module using the given name
param (Parameter): parameter to be added to the module.
"""
if '_parameters' not in self.__dict__:
# 初始化检查,要先有`self._parameters`容器
raise AttributeError(
"cannot assign parameter before Module.__init__() call")
# 下面是一些parameter的key的类型检查,存在性检查
elif not isinstance(name, torch._six.string_classes):
raise TypeError("parameter name should be a string. "
"Got {}".format(torch.typename(name)))
elif '.' in name:
raise KeyError("parameter name can't contain \".\"")
elif name == '':
raise KeyError("parameter name can't be empty string \"\"")
elif hasattr(self, name) and name not in self._parameters:
raise KeyError("attribute '{}' already exists".format(name))
# parameter 类型检查
if param is None:
self._parameters[name] = None
elif not isinstance(param, Parameter):
raise TypeError("cannot assign '{}' object to parameter '{}' "
"(torch.nn.Parameter or None required)"
.format(torch.typename(param), name))
elif param.grad_fn:
# 这里涉及到pytorch的autograd设计机制。最表象的规定就是nn.Parameter必须是一个
# 叶节点。 如果这个tensor含有`.grad_fn`属性,那么它就不是叶节点,就不能赋值。
raise ValueError(
"Cannot assign non-leaf Tensor to parameter '{0}'. Model "
"parameters must be created explicitly. To express '{0}' "
"as a function of another Tensor, compute the value in "
"the forward() method.".format(name))
else:
self._parameters[name] = param
又有一个问题,什么时候使用这个函数?(至少我之前没有用过)
上面已经分析过知道parameter
和buffer
的区别了,parameter
是需要优化器训练学习的。
顺着BatchNorm
里的running_mean
找到nn.modules.batchnorm
,这里提供了register_buffers, register_parameters
这两个函数的使用样例。
if self.affine:
self.weight = Parameter(torch.Tensor(num_features))
self.bias = Parameter(torch.Tensor(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
else:
self.register_parameter('running_mean', None)
self.register_parameter('running_var', None)
self.register_parameter('num_batches_tracked', None)
BatchNorm
里的running_mean, runing_var
在model.train(), model.eval()
两个状态下是不同的!所以可以通过这两个函数进行设置。
至于其他什么场景下会使用到这两个函数,就要看实际需求了。
add module
给当前组件添加子组件,套娃警告 >_>
def add_module(self, name: str, module: Optional['Module']) -> None:
r"""Adds a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (string): name of the child module. The child module can be
accessed from this module using the given name
module (Module): child module to be added to the module.
"""
# 参数检查
# 是否是Module
if not isinstance(module, Module) and module is not None:
raise TypeError("{} is not a Module subclass".format(
torch.typename(module)))
# 子组件名称
elif not isinstance(name, torch._six.string_classes):
raise TypeError("module name should be a string. Got {}".format(
torch.typename(name)))
# 子组件名称是否重复
elif hasattr(self, name) and name not in self._modules:
raise KeyError("attribute '{}' already exists".format(name))
elif '.' in name:
raise KeyError("module name can't contain \".\"")
elif name == '':
raise KeyError("module name can't be empty string \"\"")
self._modules[name] = module
_apply
从名字就可以看出,这个函数是个建议只在类内使用的函数。本类中出现次数最多,使用最多的函数。因为 很多有了上面那个函数可以套娃,所以在对组件进行操作的时候,需要对子组件进行同样的操作。
def _apply(self, fn):
# 循环嵌套
for module in self.children():
module._apply(fn)
def compute_should_use_set_data(tensor, tensor_applied):
pass # 内置函数
# 对自身的参数进行操作,包括参数本身,梯度
for key, param in self._parameters.items():
pass
# 对buffer内的参数进行操作
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
apply
递归调用fn
函数到每一个子模块。暴露给用户的接口,经常在初始化模型层参数的时候会用到。相当简单。
def apply(self: T, fn: Callable[['Module'], None]) -> T:
"""
Args:
fn (:class:`Module` -> None): function to be applied to each submodule
Returns:
Module: self
"""
for module in self.children():
module.apply(fn)
fn(self)
return self
cuda cpu type float…
省略号里还包括了double half bfloat16等函数,这些函数都比较简单。cuda(), cpu()
的作用就是
把参数从一个device移到另一个device上。剩下的都是类型转换。都需要用到上面讲到的self._apply()
函数。它们返回的都是模块本身。
to
这个函数和cuda(), cpu()
类似,可以看到函数内定义了一个子函数convert(t)
,用于将state迁移
到其他设备上。最后做了一个_apply()
函数,应用到所有子模块。
def to(self, *args, **kwargs):
...
def convert(t):
if convert_to_format is not None and t.dim() == 4:
return t.to(device, dtype if t.is_floating_point() else None, non_blocking, memory_format=convert_to_format)
return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
return self._apply(convert)
hook
有关hook
,我在之前的一篇博客里写过相关的介绍,可以移步去看看。
魔法方法
名称前后各有两个下划线的方法,具有特殊意义的。这些方法过一遍就好,主要是类本身属性相关的一些内容。
setstate
更新self.__dict__.update(state)
。
getattr
获取属性,主要包括parameters
,buffers
,module
中的属。首先在self.__dict__
找到这几个字典,
然后通过name
属性去查找。
def __getattr__(self, name: str) -> Union[Tensor, 'Module']:
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
if '_buffers' in self.__dict__:
_buffers = self.__dict__['_buffers']
if name in _buffers:
return _buffers[name]
if '_modules' in self.__dict__:
modules = self.__dict__['_modules']
if name in modules:
return modules[name]
raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
setattr
和__getattr__
相对应,设置的时候也分 parameters
,buffers
,module
三类。以parameters
为例,
1.先获取已有的params,2.对要设置的value进行类型检查,3.对要插入的name(key)进行存在性检测。
另外两个类似。
def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
def remove_from(*dicts_or_sets):
for d in dicts_or_sets:
if name in d:
if isinstance(d, dict):
del d[name]
else:
d.discard(name)
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
pass
else:
buffers = self.__dict__.get('_buffers')
if buffers is not None and name in buffers:
pass
delattr
参考上面两个函数。
def __delattr__(self, name):
if name in self._parameters:
del self._parameters[name]
elif name in self._buffers:
del self._buffers[name]
self._non_persistent_buffers_set.discard(name)
elif name in self._modules:
del self._modules[name]
else:
object.__delattr__(self, name)
repr extra_repr
extra_repr
是__repr__
的用户补充,即我们可以根据自己的需要自定义模组的repr属性。
__repr__
的定义如下:
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split('\n')
child_lines = []
for key, module in self._modules.items():
mod_str = repr(module)
mod_str = _addindent(mod_str, 2)
child_lines.append('(' + key + '): ' + mod_str)
lines = extra_lines + child_lines
main_str = self._get_name() + '('
if lines:
# simple one-liner info, which most builtin Modules will use
if len(extra_lines) == 1 and not child_lines:
main_str += extra_lines[0]
else:
main_str += '\n ' + '\n '.join(lines) + '\n'
main_str += ')'
return main_str
先获取extra_repr
表示,然后遍历子module
,按照key
,构造module
名字,根据层级在前面加空格,然后拼接起来。
dir
python类内置的方法,用于返回字母顺序排列的属性列。这里module
除了默认属性,
还要加上parameters
,buffers
,modules
这些。
代码很简洁:
def __dir__(self):
module_attrs = dir(self.__class__)
attrs = list(self.__dict__.keys())
parameters = list(self._parameters.keys())
modules = list(self._modules.keys())
buffers = list(self._buffers.keys())
keys = module_attrs + attrs + parameters + modules + buffers
# Eliminate attrs that are not legal Python variable names
keys = [key for key in keys if not key[0].isdigit()]
return sorted(keys)
state_dict load_state_dict
state_dict()
返回一个包含整个module状态的字典,包括参数和persistent
的buffers。键就是这些参数和buffer的名字。
通过函数描述,我觉得我也能写出这个函数来。)
def state_dict(self, destination=None, prefix='', keep_vars=False):
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
load_state_dict()
也是常用的一个函数,但是要注意当前使用的模型的model.state_dict
的key要和导入的这个匹配。
因为源码会对这个进行检查,另外strict
参数可以限制是否严格限制状态匹配。
def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
strict: bool = True):
# 检查匹配,缺失的,预期有但是没有的,错误信息。
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
# 内置函数,使用_load_from_state_dict导入,隐去了处理的细节。
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(self)
load = None # break load->load reference cycle
# strict的两种模式,如果设置了且存在error就会引发RuntimeError,否则没有。
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(
', '.join('"{}"'.format(k) for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
_named_members
这是后续两大类方法的一个辅助函数,用于生成多种 name + members of modules.
def _named_members(self, get_members_fn, prefix='', recurse=True):
# 由于named member 都是Orderdict类型,所以这里是set类型
memo = set()
# 递归获取所有的module
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
for module_prefix, module in modules:
# 这里 `get_members_fn` 用函数作为变量,根据函数不同动态获取不同的member
members = get_members_fn(module)
# 遍历members对象,构造生成器
for k, v in members:
if v is None or v in memo:
continue
memo.add(v)
name = module_prefix + ('.' if module_prefix else '') + k
yield name, v
named_parameters parameters
named_parameters()
返回遍历所有模型参数的迭代器,生成 名字+参数。
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
# 调用上面的_named_memebers()方法,注意这里传入的`get_members_fn`是个lambda函数
# lambda接受一个module参数,返回他的_parameters内容(在module定义部分可以找到,这是个OrderDict类型变量)
gen = self._named_members(
lambda module: module._parameters.items(),
prefix=prefix, recurse=recurse)
for elem in gen:
yield elem
parameters()
是named_parameters()
的一个简化函数,因为它不需要name属性。
这个函数通常是用来将参数传递给优化器。所以代码也很简洁。
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
for name, param in self.named_parameters(recurse=recurse):
yield param
named_buffers buffers
同上,这里get_members_fn
是 lambda module: module._buffers.item()
.
children modules
children(), named_children()
和modules(), named_modules()
这两类方法比较类似。
named_children()
返回生成直接子模型(即只有下一层)的迭代器,生成 名字 + 子模型。
注意这里并不会递归迭代。
def named_children(self) -> Iterator[Tuple[str, 'Module']]:
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
named_modules()
则是对上面的补充,它会返回模型所有的子模型,即它会递归遍历。
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = ''):
if memo is None:
memo = set()
if self not in memo:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
# 递归调用
for m in module.named_modules(memo, submodule_prefix):
yield m
合理利用这两个函数,我们可以定位到任意我们需要的层级。
train eval
这两个函数源码没有意思,因为只是递归设置了状态。需要再挖深一点才能看到有什么区别。
requiresgrad
用于设置autograd是否需要记录这个组件上参数的操作。
它会设置该组件所有参数的requires_grad
参数。
当我们需要freeze某一个组件的时候十分有用,可以用于finetune或者只训练模型的一部分。
def requires_grad_(self: T, requires_grad: bool = True) -> T:
for p in self.parameters():
p.requires_grad_(requires_grad)
return self
zero_grad
设置组件的所有参数的梯度为0,用于下一步更新之类的。正常训练过程中必须要用的一步骤。
def zero_grad(self, set_to_none: bool = False) -> None:
if getattr(self, '_is_replica', False):
warnings.warn(
"Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
"The parameters are copied (in a differentiable manner) from the original module. "
"This means they are not leaf nodes in autograd and so don't accumulate gradients. "
"If you need gradients in your forward method, consider using autograd.grad instead.")
for p in self.parameters():
if p.grad is not None:
if set_to_none:
p.grad = None
else:
if p.grad.grad_fn is not None:
p.grad.detach_()
else:
p.grad.requires_grad_(False)
p.grad.zero_()
以上就是nn.Module
类的一个大概了。