hook 是程序设计领域一个概念,引用一个博客的说法:
In general, “hooks” are functions that automatically execute after a particular event. 1
1. What are hooks? ↩
Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。2
2. Hook 是 PyTorch 中一个十分有用的特性 ↩
在 PyTorch 中,hook 可用于 Tensor
或 nn.Module
中。
1. Hook For Tensor
在 PyTorch 的计算图(computation graph)中,只有叶子结点(leaf nodes)的变量会保留梯度。而所有中间变量的梯度只被用于反向传播,一旦完成反向传播,中间变量的梯度就将自动释放,从而节约内存。举个例子:
1 | import torch |
其中只有 x
, y
, w
三个叶子变量的梯度在 backward
操作之后还能访问到,z
和 o
是中间变量(o
是最终结果变量,但也不是叶子结点,在 backward
这个操作上算是中间变量),不保留梯度,输出它们的 grad
属性是 None
。
如果需要保存中间结点的梯度,则需要对中间变量使用 retain_grad()
方法,如下:
1 | import torch |
这样中间变量的梯度在反向传播之后也可以保存并被访问到。
但是,这种加 retain_grad()
的方案会增加内存占用,并不是个好办法,对此的一种替代方案,就是用 hook 保存中间变量的梯度。
对于中间变量 z
,hook 的使用方式为:z.register_hook(hook_fn)
,其中 hook_fn
为一个用户自定义的函数,其签名为:
1 | hook_fn(grad) -> Tensor or None |
它的输入为变量 z 的梯度(这是固定的),输出为一个 Tensor
或者是 None
(None
一般用于直接打印梯度)。反向传播时,梯度传播到变量 z
,再继续向前传播之前,将会传入 hook_fn
。如果 hook_fn
的返回值是 None
,那么梯度将不改变,继续向前传播;如果 hook_fn
的返回值是 Tensor
类型,则该 Tensor
将取代 z
原有的梯度,向前传播。
下面的示例代码中 hook_fn
不改变梯度值,仅仅是打印梯度:
1 | import torch |
z
绑定了 hook_fn
后,梯度反向传播时将会打印出 o
对 z
的偏导,和上文中 z.retain_grad()
方法得到的 z
的偏导一致。
如果在 hook_fn
中改变梯度值:
1 | import torch |
发现 z
的梯度变为两倍后,受其影响,x
和 y
的梯度也都变成了原来的两倍。
上面说到,若要改变梯度,则 hook_fn
必须返回一个 Tensor
,如果仅仅是在 hook_fn
中改变参数 grad
的值,并不会改变变量原有的梯度,更不会影响回传的梯度,把上面的例子做一下改动:
1 | import torch |
上面这个例子中,hook_fn
函数虽然对 grad
做了个 inplace 赋值操作,但因为没有返回一个 Tensor
,所以并没有真正改变 z
的梯度,从 lambda 函数的输出也可以看出,z
的梯度并没有变成 2 倍,同样 x
和 y
的梯度也没有变化。
对于一个 Tensor
,register_hook
函数可以被调用多次,每次调用的 hook_fn
函数被保存在 tensor._backward_hooks
,这是 一个 OrderedDict
对象中,下面举一个例子看:
1 | import torch |
注册的 hook 函数可以被 remove,这个 remove 过程如果在 backward 之前,那么不会对变量的梯度产生作用,如果发生在 backward 之后,那么还是会对变量的梯度产生作用。
先来个 remove 发生在 backward 之前的:
1 | import torch |
再来个 remove 发生在 backward 之后的:
1 | import torch |
针对 Tensor
的 hook 的使用场景一般不多,最常用的 hook 是针对 nn.Module
的。
2. Hook for nn.Module
说明:这部分内容基本上来自 半小时学会 PyTorch Hook,其中加入了部分自己的分析。
网络模块 nn.module
不像上一节中的 Tensor
,拥有显式的变量名可以直接访问,而是被封装在神经网络中间。我们通常只能获得网络整体的输入和输出,对于夹在网络中间的模块,我们不但很难得知它输入/输出的梯度,甚至连它输入输出的数值都无法获得。除非设计网络时,在 forward
函数的返回值中包含中间 module 的输出,或者用很麻烦的办法,把网络按照 module 的名称拆分再组合,让中间层提取的 feature 暴露出来。
为了解决这个麻烦,PyTorch 设计了两种 hook:register_forward_hook
和 register_backward_hook
,分别用来获取正/反向传播时,中间层模块输入和输出的 feature/gradient,大大降低了获取模型内部信息流的难度。
2.1. register forward hook
register_forward_hook
的作用是获取前向传播过程中,各个网络模块的输入和输出。对于模块 module
,其使用方式为:module.register_forward_hook(hook_fn)
。其中 hook_fn
的签名为:
1 | hook_fn(module, input, output) -> None |
它的输入变量分别为:模块,模块的输入,模块的输出。和对 借助这个 hook,我们可以方便地用预训练的神经网络提取特征,而不用改变预训练网络的结构。Tensor
的 hook 不同,forward hook 不返回任何值,也就是说不能用来修改输入或者输出的值,但
register_forward_hook
的源码如下:
1 | def register_forward_hook(self, hook): |
从这个源码注释中我们可以获得一些信息:
- hook 函数可以改变
output
的值,只要返回值不是None
,这个返回值就是修改后的output
值。这个与Tensor
中的 hook 是一样的 - 可以 inplace 地改变
input
的值,但对 forward 过程不会有影响,因为 hook 是在forward
函数之后被调用的 - 因为 hook 是在
forward
函数之后被调用的,所以register_forward_hook()
函数必须在forward()
函数调用之前被使用,如果在forward()
函数之后再执行register_forward_hook()
,被 register 的 hook 函数已经不能对 forward 过程产生影响或者获取 forward 过程的中间值了。
下面提供一段示例代码:
1 | import torch |
输出如下:
1 | Linear(in_features=3, out_features=4, bias=True) |
2.2. register backward hook
和 register_forward_hook
相似,register_backward_hook
的作用是获取神经网络反向传播过程中,各个模块输入端和输出端的梯度值。对于模块 module
,其使用方式为:module.register_backward_hook(hook_fn)
。其中 hook_fn
的函数签名为:
1 | hook_fn(module, grad_input, grad_output) -> Tensor or None |
它的输入变量分别为:模块,模块输入端的梯度,模块输出端的梯度。需要注意的是,这里的输入端和输出端,是站在前向传播的角度的,而不是反向传播的角度。例如线性模块:o = W * x + b
,其输入端为 W
,x
和 b
,输出端为 o
。
如果模块有多个输入或者输出的话,grad_input
和 grad_output
可以是 tuple 类型。对于线性模块:o = W * x + b
,它的输入端包括了 W
、x
和 b
三部分,因此 grad_input
就是一个包含三个元素的 tuple。
这里注意和 forward hook 的不同:
1.在 forward hook 中,input 是 x
,而不包括 W
和 b
。
2.返回 Tensor
或者 None
,backward hook 函数不能直接改变它的输入变量,但是可以返回新的 grad_input
,反向传播到它上一个模块。
举个例子:
1 | import torch |
输出如下:
1 | Linear(in_features=4, out_features=1, bias=True) |
需要注意的是,对线性模块,其 grad_input
是一个三元组,排列顺序分别为:对偏置 b
的导数,对输入 x
的导数,对权重 W
的导数。
backward hook 在全连接层和卷积层表现不一致的地方:
- 形状:
- 在卷积层中,weight 的梯度和 weight 的形状相同
- 在全连接层中,weight 的梯度的形状是 weight 形状的转秩
grad_input
tuple 中各梯度的顺序- 在卷积层中,bias 的梯度位于tuple 的末尾:
grad_input
= (对 feature 的导数,对权重W
的导数,对 bias 的导数) - 在全连接层中,bias 的梯度位于 tuple 的开头:
grad_input
= (对 bias 的导数,对 feature 的导数,对W
的导数)
- 在卷积层中,bias 的梯度位于tuple 的末尾:
- 当 batch size > 1时,对 bias 的梯度处理不同
- 在卷积层,对 bias 的梯度为整个 batch 的数据在 bias 上的梯度之和:
grad_input
= (对feature的导数,对权重 W 的导数,对 bias 的导数) - 在全连接层,对 bias 的梯度是分开的,bach 中每条数据,对应一个 bias 的梯度:
grad_input
= ((data1 对 bias 的导数,data2 对 bias 的导数 …),对 feature 的导数,对 W 的导数)
- 在卷积层,对 bias 的梯度为整个 batch 的数据在 bias 上的梯度之和:
3. 示例
示例可以参见 Guided Backpropagation 示例
4. 参考资料
[1]. 半小时学会 PyTorch Hook