0%

PyTorch 中的 hook

hook 是程序设计领域一个概念,引用一个博客的说法:

In general, “hooks” are functions that automatically execute after a particular event. 1

1. What are hooks?

Hook 是 PyTorch 中一个十分有用的特性。利用它,我们可以不必改变网络输入输出的结构,方便地获取、改变网络中间层变量的值和梯度。这个功能被广泛用于可视化神经网络中间层的 feature、gradient,从而诊断神经网络中可能出现的问题,分析网络有效性。2

2. Hook 是 PyTorch 中一个十分有用的特性

在 PyTorch 中,hook 可用于 Tensornn.Module 中。

1. Hook For Tensor

在 PyTorch 的计算图(computation graph)中,只有叶子结点(leaf nodes)的变量会保留梯度。而所有中间变量的梯度只被用于反向传播,一旦完成反向传播,中间变量的梯度就将自动释放,从而节约内存。举个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

w = torch.tensor(4.0, requires_grad=True)

z = x + y
o = w * z

o.backward()

print(x.grad)
print(y.grad)
print(w.grad)
print(z.grad)
print(o.grad)

"""
output:
tensor(4.)
tensor(4.)
tensor(5.)
None
None
"""

其中只有 x, y, w 三个叶子变量的梯度在 backward 操作之后还能访问到,zo 是中间变量(o 是最终结果变量,但也不是叶子结点,在 backward 这个操作上算是中间变量),不保留梯度,输出它们的 grad 属性是 None

如果需要保存中间结点的梯度,则需要对中间变量使用 retain_grad() 方法,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

w = torch.tensor(4.0, requires_grad=True)

z = x + y
z.retain_grad()
o = w * z
o.retain_grad()

o.backward()

print(x.grad)
print(y.grad)
print(w.grad)
print(z.grad)
print(o.grad)

"""
output:
tensor(4.)
tensor(4.)
tensor(5.)
tensor(4.)
tensor(1.)
"""

这样中间变量的梯度在反向传播之后也可以保存并被访问到。

但是,这种加 retain_grad() 的方案会增加内存占用,并不是个好办法,对此的一种替代方案,就是用 hook 保存中间变量的梯度。

对于中间变量 z,hook 的使用方式为:z.register_hook(hook_fn),其中 hook_fn为一个用户自定义的函数,其签名为:

1
hook_fn(grad) -> Tensor or None

它的输入为变量 z 的梯度(这是固定的),输出为一个 Tensor 或者是 NoneNone 一般用于直接打印梯度)。反向传播时,梯度传播到变量 z,再继续向前传播之前,将会传入 hook_fn如果 hook_fn 的返回值是 None,那么梯度将不改变,继续向前传播;如果 hook_fn 的返回值是 Tensor 类型,则该 Tensor 将取代 z 原有的梯度,向前传播。

下面的示例代码中 hook_fn 不改变梯度值,仅仅是打印梯度:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

w = torch.tensor(4.0, requires_grad=True)

z = x + y

def hook_fn(grad):
print("z grad:", grad)

z.register_hook(hook_fn)

o = w * z

print('=====Start backprop=====')
o.backward()
print('=====End backprop=====')

"""
运行到这里会产生输出:
=====Start backprop=====
z grad: tensor(4.)
=====End backprop=====
"""

print(x.grad)
print(y.grad)
print(w.grad)
print(z.grad)
print(o.grad)

"""
运行到这里会产生输出:
tensor(4.)
tensor(4.)
tensor(5.)
None
None
"""

z 绑定了 hook_fn 后,梯度反向传播时将会打印出 oz 的偏导,和上文中 z.retain_grad() 方法得到的 z 的偏导一致。

如果在 hook_fn 中改变梯度值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

w = torch.tensor(4.0, requires_grad=True)

z = x + y

def hook_fn(grad):
g = 2 * grad
print("z grad multiplies 2:", g)
return g

z.register_hook(hook_fn)

o = w * z

print('=====Start backprop=====')
o.backward()
print('=====End backprop=====')

"""
运行到这里会产生输出:
=====Start backprop=====
z grad multiplies 2: tensor(8.)
=====End backprop=====
"""

print(x.grad)
print(y.grad)
print(w.grad)
print(z.grad)
print(o.grad)

"""
运行到这里会产生输出:
tensor(8.)
tensor(8.)
tensor(5.)
None
None
"""

发现 z 的梯度变为两倍后,受其影响,xy 的梯度也都变成了原来的两倍。

上面说到,若要改变梯度,则 hook_fn 必须返回一个 Tensor,如果仅仅是在 hook_fn 中改变参数 grad 的值,并不会改变变量原有的梯度,更不会影响回传的梯度,把上面的例子做一下改动:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)

w = torch.tensor(4.0, requires_grad=True)

z = x + y

def hook_fn(grad):
grad = grad * 2

z.register_hook(hook_fn)
z.register_hook(lambda grad: print("z grad", grad))

o = w * z

print('=====Start backprop=====')
o.backward()
print('=====End backprop=====')

"""
运行到这里会产生输出:
=====Start backprop=====
z grad tensor(4.)
=====End backprop=====
"""

print(x.grad)
print(y.grad)
print(w.grad)
print(z.grad)
print(o.grad)

"""
运行到这里会产生输出:
tensor(4.)
tensor(4.)
tensor(5.)
None
None
"""

上面这个例子中,hook_fn 函数虽然对 grad 做了个 inplace 赋值操作,但因为没有返回一个 Tensor,所以并没有真正改变 z 的梯度,从 lambda 函数的输出也可以看出,z 的梯度并没有变成 2 倍,同样 xy 的梯度也没有变化。

对于一个 Tensorregister_hook 函数可以被调用多次,每次调用的 hook_fn 函数被保存在 tensor._backward_hooks,这是 一个 OrderedDict 对象中,下面举一个例子看:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b

def c_hook(grad):
print(grad)
return grad + 2

# c 在这里注册了两个 hook
c.register_hook(c_hook)
c.register_hook(lambda grad: print(grad))
# c 是中间变量,需要用 retain_grad 留住梯度,否则后面会被清除
c.retain_grad()

print(c._backward_hooks)
"""
output:
OrderedDict([(43, <function __main__.c_hook(grad)>),
(44, <function __main__.<lambda>(grad)>),
(45,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>)])
"""

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)

e = c * d

# e 也是中间变量,为了留住梯度也用了 retain_grad
e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()

print(e._backward_hooks)
"""
retain_grad 可以被调用多次,但在 OrderedDict 中只保存一个
output:
OrderedDict([(47,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>),
(48, <function __main__.<lambda>(grad)>)])
"""

e.backward()
"""
这句会产生输出,这个输出是 c 的两个 hook 里的 print 引起的
output:
tensor(8.)
tensor(10.)
"""

print(e.grad)
print(d.grad)
print(c.grad)
print(a.grad)
print(b.grad)

"""
output:
tensor(1.)
tensor(112.)
tensor(10.)
tensor(30.)
tensor(20.)
"""

注册的 hook 函数可以被 remove,这个 remove 过程如果在 backward 之前,那么不会对变量的梯度产生作用,如果发生在 backward 之后,那么还是会对变量的梯度产生作用。

先来个 remove 发生在 backward 之前的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b

def c_hook(grad):
print(grad)
return grad + 2

# c 在这里注册了两个 hook,将第一个 hook 赋值给一个变量做标记,用于后面 remove
h = c.register_hook(c_hook)
c.register_hook(lambda grad: print(grad))
c.retain_grad()

print(c._backward_hooks)
"""
output:
OrderedDict([(43, <function __main__.c_hook(grad)>),
(44, <function __main__.<lambda>(grad)>),
(45,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>)])
"""

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)

e = c * d

e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()

print(e._backward_hooks)
"""
retain_grad 可以被调用多次,但在 OrderedDict 中只保存一个
output:
OrderedDict([(47,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>),
(48, <function __main__.<lambda>(grad)>)])
"""

# 在 backward 之前 remove c 的 hook
h.remove()
print(c._backward_hooks)
"""
c 的第一个 hook 没了
output:
OrderedDict([(68, <function __main__.<lambda>(grad)>),
(69,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>)])
"""

e.backward()
"""
这时候只有一个输出了,注意这个 tensor(8.)
是由 c.register_hook(lambda grad: print(grad)) 引起的,
因为前一个 hook 被 remove 了,不存在 grad + 2 这个过程
output:
tensor(8.)
"""

print(e.grad)
print(d.grad)
print(c.grad)
print(a.grad)
print(b.grad)

"""
c 的梯度就是 8,同时 a 和 b 的梯度也会变化
output:
tensor(1.)
tensor(112.)
tensor(8.)
tensor(24.)
tensor(16.)
"""

再来个 remove 发生在 backward 之后的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)

c = a * b

def c_hook(grad):
print(grad)
return grad + 2

# c 在这里注册了两个 hook,将第一个 hook 赋值给一个变量做标记,用于后面 remove
h = c.register_hook(c_hook)
c.register_hook(lambda grad: print(grad))
c.retain_grad()

print(c._backward_hooks)
"""
output:
OrderedDict([(43, <function __main__.c_hook(grad)>),
(44, <function __main__.<lambda>(grad)>),
(45,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>)])
"""

d = torch.tensor(4.0, requires_grad=True)
d.register_hook(lambda grad: grad + 100)

e = c * d

e.retain_grad()
e.register_hook(lambda grad: grad * 2)
e.retain_grad()

print(e._backward_hooks)
"""
retain_grad 可以被调用多次,但在 OrderedDict 中只保存一个
output:
OrderedDict([(47,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>),
(48, <function __main__.<lambda>(grad)>)])
"""

e.backward()
"""
到这里还没有发生 remove
output:
tensor(8.)
tensor(10.)
"""

# 在 backward 之后 remove c 的 hook,这时对于梯度变化已经没有影响了
h.remove()
print(c._backward_hooks)
"""
c 的第一个 hook 没了
output:
OrderedDict([(74, <function __main__.<lambda>(grad)>),
(75,
<function torch._tensor.Tensor.retain_grad.<locals>.retain_grad_hook(grad)>)])
"""

print(e.grad)
print(d.grad)
print(c.grad)
print(a.grad)
print(b.grad)

"""
梯度不因 remove 操作发生变化
output:
tensor(1.)
tensor(112.)
tensor(10.)
tensor(30.)
tensor(20.)
"""

针对 Tensor 的 hook 的使用场景一般不多,最常用的 hook 是针对 nn.Module 的。

2. Hook for nn.Module

说明:这部分内容基本上来自 半小时学会 PyTorch Hook,其中加入了部分自己的分析。

网络模块 nn.module 不像上一节中的 Tensor,拥有显式的变量名可以直接访问,而是被封装在神经网络中间。我们通常只能获得网络整体的输入和输出,对于夹在网络中间的模块,我们不但很难得知它输入/输出的梯度,甚至连它输入输出的数值都无法获得。除非设计网络时,在 forward 函数的返回值中包含中间 module 的输出,或者用很麻烦的办法,把网络按照 module 的名称拆分再组合,让中间层提取的 feature 暴露出来。

为了解决这个麻烦,PyTorch 设计了两种 hook:register_forward_hookregister_backward_hook,分别用来获取正/反向传播时,中间层模块输入和输出的 feature/gradient,大大降低了获取模型内部信息流的难度。

2.1. register forward hook

register_forward_hook 的作用是获取前向传播过程中,各个网络模块的输入和输出。对于模块 module,其使用方式为:module.register_forward_hook(hook_fn)。其中 hook_fn 的签名为:

1
hook_fn(module, input, output) -> None

它的输入变量分别为:模块,模块的输入,模块的输出。和对 Tensor 的 hook 不同,forward hook 不返回任何值,也就是说不能用来修改输入或者输出的值,但借助这个 hook,我们可以方便地用预训练的神经网络提取特征,而不用改变预训练网络的结构。

register_forward_hook 的源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def register_forward_hook(self, hook):
r"""Registers a forward hook on the module.
The hook will be called every time after :func:`forward` has computed an output.
It should have the following signature::
hook(module, input, output) -> None or modified output
The hook can modify the output. It can modify the input inplace but
it will not have effect on forward since this is called after
:func:`forward` is called.

Returns:
:class:`torch.utils.hooks.RemovableHandle`:
a handle that can be used to remove the added hook by calling
``handle.remove()``
"""
handle = hooks.RemovableHandle(self._forward_hooks)
self._forward_hooks[handle.id] = hook
return handle

从这个源码注释中我们可以获得一些信息:

  • hook 函数可以改变 output 的值,只要返回值不是 None,这个返回值就是修改后的 output 值。这个与 Tensor 中的 hook 是一样的
  • 可以 inplace 地改变 input 的值,但对 forward 过程不会有影响,因为 hook 是在 forward 函数之后被调用的
  • 因为 hook 是在 forward 函数之后被调用的,所以 register_forward_hook() 函数必须在 forward() 函数调用之前被使用,如果在 forward() 函数之后再执行 register_forward_hook() ,被 register 的 hook 函数已经不能对 forward 过程产生影响或者获取 forward 过程的中间值了。

下面提供一段示例代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch import nn

# 首先我们定义一个模型
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(4, 1)
self.initialize()

# 为了方便验证,我们将指定特殊的weight和bias
def initialize(self):
with torch.no_grad():
self.fc1.weight = torch.nn.Parameter(
torch.Tensor([[1., 2., 3.],
[-4., -5., -6.],
[7., 8., 9.],
[-10., -11., -12.]]))

self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0]))

def forward(self, x):
o = self.fc1(x)
o = self.relu1(o)
o = self.fc2(o)
return o

# 全局变量,用于存储中间层的 feature
total_feat_out = []
total_feat_in = []

# 定义 forward hook function
def hook_fn_forward(module, input, output):
print(module) # 用于区分模块
print('input: ', input) # 首先打印出来
print('output: ', output)
total_feat_out.append(output) # 然后分别存入全局 list 中
total_feat_in.append(input)


model = Model()

# 在这里注册 hook,是在 forward() 函数之前,因为 forward() 是在 o = model(x) 这一步执行的
modules = model.named_children() #
for name, module in modules:
module.register_forward_hook(hook_fn_forward)

# 注意下面代码中 x 的维度,对于linear module,输入一定是大于等于二维的
# (第一维是 batch size)。在 forward hook 中看不出来,但是 backward hook 中,
# 得到的梯度完全不对。
# 有一篇 hook 的教程就是这里出了错,作者还强行解释,遗毒无穷,

x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_()
o = model(x)
o.backward()

print('==========Saved inputs and outputs==========')
for idx in range(len(total_feat_in)):
print('input: ', total_feat_in[idx])
print('output: ', total_feat_out[idx])

输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
Linear(in_features=3, out_features=4, bias=True)
input: (tensor([[1., 1., 1.]], requires_grad=True),)
output: tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>)
ReLU()
input: (tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>),)
output: tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>)
Linear(in_features=4, out_features=1, bias=True)
input: (tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>),)
output: tensor([[89.]], grad_fn=<AddmmBackward>)
==========Saved inputs and outputs==========
input: (tensor([[1., 1., 1.]], requires_grad=True),)
output: tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>)
input: (tensor([[ 7., -13., 27., -29.]], grad_fn=<AddmmBackward>),)
output: tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>)
input: (tensor([[ 7., 0., 27., 0.]], grad_fn=<ReluBackward0>),)
output: tensor([[89.]], grad_fn=<AddmmBackward>)

2.2. register backward hook

register_forward_hook 相似,register_backward_hook 的作用是获取神经网络反向传播过程中,各个模块输入端和输出端的梯度值。对于模块 module,其使用方式为:module.register_backward_hook(hook_fn)。其中 hook_fn 的函数签名为:

1
hook_fn(module, grad_input, grad_output) -> Tensor or None

它的输入变量分别为:模块,模块输入端的梯度,模块输出端的梯度。需要注意的是,这里的输入端和输出端,是站在前向传播的角度的,而不是反向传播的角度。例如线性模块:o = W * x + b,其输入端为 Wxb,输出端为 o

如果模块有多个输入或者输出的话,grad_inputgrad_output 可以是 tuple 类型。对于线性模块:o = W * x + b,它的输入端包括了 Wxb 三部分,因此 grad_input 就是一个包含三个元素的 tuple。

这里注意和 forward hook 的不同:

1.在 forward hook 中,input 是 x,而不包括 Wb

2.返回 Tensor 或者 None,backward hook 函数不能直接改变它的输入变量,但是可以返回新的 grad_input,反向传播到它上一个模块。

举个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from torch import nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(3, 4)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(4, 1)
self.initialize()

def initialize(self):
with torch.no_grad():
self.fc1.weight = torch.nn.Parameter(
torch.Tensor([[1., 2., 3.],
[-4., -5., -6.],
[7., 8., 9.],
[-10., -11., -12.]]))

self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0]))

def forward(self, x):
o = self.fc1(x)
o = self.relu1(o)
o = self.fc2(o)
return o


total_grad_out = []
total_grad_in = []


def hook_fn_backward(module, grad_input, grad_output):
print(module) # 为了区分模块
# 为了符合反向传播的顺序,我们先打印 grad_output
print('grad_output', grad_output)
# 再打印 grad_input
print('grad_input', grad_input)
# 保存到全局变量
total_grad_in.append(grad_input)
total_grad_out.append(grad_output)


model = Model()

modules = model.named_children()
for name, module in modules:
module.register_backward_hook(hook_fn_backward)

# 这里的 requires_grad 很重要,如果不加,backward hook
# 执行到第一层,对 x 的导数将为 None,某英文博客作者这里疏忽了
# 此外再强调一遍 x 的维度,一定不能写成 torch.Tensor([1.0, 1.0, 1.0]).requires_grad_()
# 否则 backward hook 会出问题。
x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_()
o = model(x)
o.backward()

print('==========Saved inputs and outputs==========')
for idx in range(len(total_grad_in)):
print('grad output: ', total_grad_out[idx])
print('grad input: ', total_grad_in[idx])

输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
Linear(in_features=4, out_features=1, bias=True)
grad_output (tensor([[1.]]),)
grad_input (tensor([1.]), tensor([[1., 2., 3., 4.]]), tensor([[ 7.],
[ 0.],
[27.],
[ 0.]]))
ReLU()
grad_output (tensor([[1., 2., 3., 4.]]),)
grad_input (tensor([[1., 0., 3., 0.]]),)
Linear(in_features=3, out_features=4, bias=True)
grad_output (tensor([[1., 0., 3., 0.]]),)
grad_input (tensor([1., 0., 3., 0.]), tensor([[22., 26., 30.]]), tensor([[1., 0., 3., 0.],
[1., 0., 3., 0.],
[1., 0., 3., 0.]]))
==========Saved inputs and outputs==========
grad output: (tensor([[1.]]),)
grad input: (tensor([1.]), tensor([[1., 2., 3., 4.]]), tensor([[ 7.],
[ 0.],
[27.],
[ 0.]]))
grad output: (tensor([[1., 2., 3., 4.]]),)
grad input: (tensor([[1., 0., 3., 0.]]),)
grad output: (tensor([[1., 0., 3., 0.]]),)
grad input: (tensor([1., 0., 3., 0.]), tensor([[22., 26., 30.]]), tensor([[1., 0., 3., 0.],
[1., 0., 3., 0.],
[1., 0., 3., 0.]]))

需要注意的是,对线性模块,其 grad_input 是一个三元组,排列顺序分别为:对偏置 b 的导数,对输入 x 的导数,对权重 W 的导数。

backward hook 在全连接层和卷积层表现不一致的地方:

  1. 形状:
    • 在卷积层中,weight 的梯度和 weight 的形状相同
    • 在全连接层中,weight 的梯度的形状是 weight 形状的转秩
  2. grad_input tuple 中各梯度的顺序
    • 在卷积层中,bias 的梯度位于tuple 的末尾:grad_input = (对 feature 的导数,对权重 W 的导数,对 bias 的导数)
    • 在全连接层中,bias 的梯度位于 tuple 的开头:grad_input = (对 bias 的导数,对 feature 的导数,对 W 的导数)
  3. 当 batch size > 1时,对 bias 的梯度处理不同
    • 在卷积层,对 bias 的梯度为整个 batch 的数据在 bias 上的梯度之和:grad_input = (对feature的导数,对权重 W 的导数,对 bias 的导数)
    • 在全连接层,对 bias 的梯度是分开的,bach 中每条数据,对应一个 bias 的梯度:grad_input = ((data1 对 bias 的导数,data2 对 bias 的导数 …),对 feature 的导数,对 W 的导数)

3. 示例

示例可以参见 Guided Backpropagation 示例

4. 参考资料

[1]. 半小时学会 PyTorch Hook

[2]. 一个不错的视频:bilibili, YouTube

[3]. pytorch的hook机制之register_forward_hook