Pytorch中hook

本文最后更新于:2021年1月8日 晚上

hook相当于插件。可以实现一些额外的功能,而又不用修改主体代码。把这些额外功能实现了挂在主代码上,所以叫钩子,很形象。

Pytorch中常见的hook有:

第一个是register_hook,是针对Variable对象的,后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。

由于Pytorch当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉;当我们需要求中间变量的梯度等时,hook应运而生。

🔎 register_hook(hook)

🍥 register_forward_hook(hook)

官方文档解释:

module上注册一个forward hook。 每次调用forward()计算输出的时候,这个hook就会被调用。它应该拥有以下签名:

hook(module, input, output) -> None

hook不应该修改 inputoutput的值。 这个函数返回一个 句柄(handle)。它有一个方法 handle.remove(),可以用这个方法将hookmodule移除。

看这个解释可能有点蒙逼,但是如果要看一下nn.Module的源码怎么使用hook的话,那就乌云尽散了。
先看 register_forward_hook

# 在此module上注册一个hook
def register_forward_hook(self, hook):
       handle = hooks.RemovableHandle(self._forward_hooks)
       # 把注册的hook保存在_forward_hooks字典里。 
       self._forward_hooks[handle.id] = hook
       # 返回句柄 
       return handle

再看 nn.Module 的call方法(被阉割了,只留下需要关注的部分):

def __call__(self, *input, **kwargs):
   result = self.forward(*input, **kwargs)
   for hook in self._forward_hooks.values():
       #将注册的hook拿出来用
       hook_result = hook(self, input, result)
   ...
   return result

可以看到,当我们执行model(x)的时候,底层干了以下几件事:

  • 调用 forward 方法计算结果

  • 判断有没有注册 forward_hook,有的话,就将 forward 的输入及结果作为hook的实参。然后让hook自己干一些不可告人的事情。

看到这,我们就明白hook签名的意思了,还有为什么hook不能修改input的output的原因。

小例子:

import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable

def for_hook(module, input, output):
    print(module)
    for val in input:
        print("input val:",val)
    for out_val in output:
        print("output val:", out_val)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
    def forward(self, x):

        return x+1

model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()

📟 register_backward_hook

官方文档解释:

module上注册一个bachward hook

每次计算moduleinputs的梯度的时候,这个hook会被调用。hook应该拥有下面的signature

hook(module, grad_input, grad_output) -> Variable or None

如果module有多个输入输出的话,那么grad_input grad_output将会是个tuplehook不应该修改它的arguments,但是它可以选择性的返回关于输入的梯度,这个返回的梯度在后续的计算中会替代grad_input

这个函数返回一个 句柄(handle)。它有一个方法 handle.remove(),可以用这个方法将hookmodule移除。

参考链接

【1】https://www.zhihu.com/question/61044004/answer/183682138

【2】pytorch 的 hook 机制