Pytorch中hook
本文最后更新于:2021年1月8日 晚上
hook相当于插件。可以实现一些额外的功能,而又不用修改主体代码。把这些额外功能实现了挂在主代码上,所以叫钩子,很形象。
Pytorch中常见的hook有:
torch.autograd.Variable.register_hook (Python method, in Automatic differentiation package
torch.nn.Module.register_backward_hook (Python method, in torch.nn)
第一个是register_hook,是针对Variable对象的,后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。
由于Pytorch当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉;当我们需要求中间变量的梯度等时,hook应运而生。
🔎 register_hook(hook)
🍥 register_forward_hook(hook)
官方文档解释:
在
module
上注册一个forward hook
。 每次调用forward()
计算输出的时候,这个hook
就会被调用。它应该拥有以下签名:hook(module, input, output) -> None
hook
不应该修改input
和output
的值。 这个函数返回一个 句柄(handle
)。它有一个方法handle.remove()
,可以用这个方法将hook
从module
移除。
看这个解释可能有点蒙逼,但是如果要看一下nn.Module的源码怎么使用hook的话,那就乌云尽散了。
先看 register_forward_hook
# 在此module上注册一个hook
def register_forward_hook(self, hook):
handle = hooks.RemovableHandle(self._forward_hooks)
# 把注册的hook保存在_forward_hooks字典里。
self._forward_hooks[handle.id] = hook
# 返回句柄
return handle
再看 nn.Module 的call方法(被阉割了,只留下需要关注的部分):
def __call__(self, *input, **kwargs):
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
#将注册的hook拿出来用
hook_result = hook(self, input, result)
...
return result
可以看到,当我们执行model(x)
的时候,底层干了以下几件事:
调用
forward
方法计算结果判断有没有注册
forward_hook
,有的话,就将forward
的输入及结果作为hook
的实参。然后让hook自己干一些不可告人的事情。
看到这,我们就明白hook签名的意思了,还有为什么hook不能修改input的output的原因。
小例子:
import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable
def for_hook(module, input, output):
print(module)
for val in input:
print("input val:",val)
for out_val in output:
print("output val:", out_val)
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return x+1
model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()
📟 register_backward_hook
官方文档解释:
在
module
上注册一个bachward hook
。每次计算
module
的inputs
的梯度的时候,这个hook
会被调用。hook
应该拥有下面的signature
。hook(module, grad_input, grad_output) -> Variable or None
如果
module
有多个输入输出的话,那么grad_input
grad_output
将会是个tuple
。hook
不应该修改它的arguments
,但是它可以选择性的返回关于输入的梯度,这个返回的梯度在后续的计算中会替代grad_input
。这个函数返回一个 句柄(
handle
)。它有一个方法handle.remove()
,可以用这个方法将hook
从module
移除。
参考链接
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!