Pytorch中hook
本文最后更新于:2021年1月8日 晚上
hook相当于插件。可以实现一些额外的功能,而又不用修改主体代码。把这些额外功能实现了挂在主代码上,所以叫钩子,很形象。
Pytorch中常见的hook有:
torch.autograd.Variable.register_hook (Python method, in Automatic differentiation package
torch.nn.Module.register_backward_hook (Python method, in torch.nn)
第一个是register_hook,是针对Variable对象的,后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。
由于Pytorch当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉;当我们需要求中间变量的梯度等时,hook应运而生。
🔎 register_hook(hook)
🍥 register_forward_hook(hook)
官方文档解释:
在
module上注册一个forward hook。 每次调用forward()计算输出的时候,这个hook就会被调用。它应该拥有以下签名:hook(module, input, output) -> None
hook不应该修改input和output的值。 这个函数返回一个 句柄(handle)。它有一个方法handle.remove(),可以用这个方法将hook从module移除。
看这个解释可能有点蒙逼,但是如果要看一下nn.Module的源码怎么使用hook的话,那就乌云尽散了。
先看 register_forward_hook
# 在此module上注册一个hook
def register_forward_hook(self, hook):
handle = hooks.RemovableHandle(self._forward_hooks)
# 把注册的hook保存在_forward_hooks字典里。
self._forward_hooks[handle.id] = hook
# 返回句柄
return handle
再看 nn.Module 的call方法(被阉割了,只留下需要关注的部分):
def __call__(self, *input, **kwargs):
result = self.forward(*input, **kwargs)
for hook in self._forward_hooks.values():
#将注册的hook拿出来用
hook_result = hook(self, input, result)
...
return result
可以看到,当我们执行model(x)的时候,底层干了以下几件事:
调用
forward方法计算结果判断有没有注册
forward_hook,有的话,就将forward的输入及结果作为hook的实参。然后让hook自己干一些不可告人的事情。
看到这,我们就明白hook签名的意思了,还有为什么hook不能修改input的output的原因。
小例子:
import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable
def for_hook(module, input, output):
print(module)
for val in input:
print("input val:",val)
for out_val in output:
print("output val:", out_val)
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x):
return x+1
model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()
📟 register_backward_hook
官方文档解释:
在
module上注册一个bachward hook。每次计算
module的inputs的梯度的时候,这个hook会被调用。hook应该拥有下面的signature。hook(module, grad_input, grad_output) -> Variable or None如果
module有多个输入输出的话,那么grad_inputgrad_output将会是个tuple。hook不应该修改它的arguments,但是它可以选择性的返回关于输入的梯度,这个返回的梯度在后续的计算中会替代grad_input。这个函数返回一个 句柄(
handle)。它有一个方法handle.remove(),可以用这个方法将hook从module移除。
参考链接
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!