hook
函数可以在你不修改模型代码的同时,帮助你提取(或修改)中间层的参数或者特征图。
最近在更新PAF的时候,添加了Projector (文档Embedding)功能,需要将网络靠近输出的一维特征取出来做PCA/tsne分析(辅助分析虽然看起来好像没什么用,但是用来装13挺合适的)。问题在于取出对应层embedding的方法。因为训练测试需要完整将输入通过网络得到输出,尽管我可以在网络定义的类中重新写一个函数使得输入只执行到我想要的那层,然后返回给我embedding。这是个笨方法,而且效率很低。所以想起Grad-CAM实现的时候,用到的Hook函数。
什么是Hook
字面上是“钩子”的意思。通俗地讲就是插件,在不用修改主体代码的条件下,可以实现一些额外的功能。把这些额外实现的功能“钩”在主代码上,所以叫钩子。
对于我的需求,就可以通过先定义一个Hook类,然后对应层注册一下,输入执行到该层的时候hook一下,把输出给保存下来,就避免了二次执行了。
更通俗一点的需求呢?在知乎看到李斌的高赞回答里举出例子就非常好。可以移步看一下。
我对他这个例子的总结是,基于PyTorch的autograd
机制,它构建的图中,一旦完成backward
的使命,图就自动回收了。因此如果想取出中间变量的梯度,或者对中间变量的梯度进行限制,就可以通过定义一个中间变量的hook,打印grad的值或者修改等。
PyTorch 相关Hook函数解读
PyTorch中与hook
相关的函数有(可以直接看官方文档解释,也可以留在这里看,我们一起讨论):
torch.Tensor.register_hook
torch.nn.Module.register_forward_hook
torch.nn.Module.register_backward_hook
torch.nn.Module.register_forward_pre_hook
先声明一下,这些注册函数的共同点是:都会返回一个handle
手柄(类似开关),它有handle.remoev()
方法,可以将hook
移除。
torch.Tensor.register_hook
针对torch.Tensor
,注册一个backward
的钩子,每次计算这个Tensor的梯度的时候,都会触发钩子,调用钩子函数。
这个hook
函数长这个样子:
hook(grad) -> Tensor or None
hook
函数不应该改变它的参数grad
,但可以选择返回一个新的参数来替换原来的grad
,也达到修改的目的。
下面这个例子在官方的例子上进行了一点修改。
# double the grad
def hook(grad):
return grad*2
v = torch.tensor([0., 0., 0.], requires_grad=True) # set requires_grad to True
h = v.register_hook(hook) # remember h is a handle
v.backward(torch.tensor([1., 2., 3.])) # torch.Tensor.backward(gradient=None, ...)
v.grad # tensor([2., 4., 6.])
h.remove() # remove the hook
可以自己修改hook
函数体会一下。
torch.nn.Module.register_forward_hook
在torch.nn.Module
上挂一个前向传播的钩子函数。
每次forward()
计算得到输出之后就会调用hook
,它的形式定义如下:
hook(module, input, output) -> None
这里的hook不会对输入或者输出做改变,只取不改。
torch.nn.Module.register_backward_hook
针对torch.nn.Module
,在它上面注册一个钩子函数。在module
计算完成的时候,调用钩子函数。
hook
函数的形式定义:
hook(module, input, output) -> Tensor or None
注意点:
- 因为
backward
是从后到前的,它的参数顺序和前向传播的相比应该是倒过来的。 - 对于前向传播,
layer2
的前一层是layer1
;对于后向传播,layer2
的前一层是layer3
。 - 模型的
output
是前向最后一层的output
所以文档里说明如果有返回值,替换的是上述定义中的input
。
这一点有些混乱,我们明确一个问题,如果修改返回值应该“修改”谁。
将hook
换个表达:
hook(module, grad_out, grad_in) -> Tensor or None
- grad_in: 模型输出对于对应层输出的梯度 # 和forward pass统一
- 等于一个表示当前层每一个神经元
error
的tensor
(等于模型输出对于层输出的梯度 or 等于它应该提升多大程度) - 对于最后一层,例如
[1,1] <=>
模型输出对于它本身的梯度 - 同样可以被当作是一个
weight map
,比如[1, 0]
关掉了第二个梯度,[2,1]
在第一个梯度上放了双倍的权重
- 等于一个表示当前层每一个神经元
- grad_out: grad_in * (模型输出对于输出的梯度)
- 等于下一层的error (链式法则)
- 等于下一层的error (链式法则)
所以如果有返回的话,返回的是下一层的error
,对应我下面这里的grad_out
。
例子建议去Understanding Pytorch hooks 查看,需要投入时间。
torch.nn.Module.register_forward_pre_hook
在torch.nn.Module
上注册一个前向传播的预先钩函数。
和register_forward_hook
相比,它是每次forward()
计算之前就被调用,它的形式定义如下:
hook(module, input) -> None
看返回值就知道,它也不能对input
进行修改。
总结
- Hook函数不改变传递进来的gradients
- 但是一旦调用
return
,返回的值将作为输出梯度。