如何优雅的提取feature map和gradient map(2个简单例子理解pytorch中的hook的使用)
AI算法与图像处理
共 5398字,需浏览 11分钟
·
2022-06-17 21:19
简介
我们在训练网络的过程中往往希望知道在forward过程中每层的输入和输出,也想知道backward过程中反传的梯度。
pytorch提供了两个钩子注册函数(register_forward_hook,register_full_backward_hook),用于获取forward和backward中输入和输出,可以在不改变网络的定义代码,不需要在forward函数中return某个感兴趣层的输出,可以获得特定层或者特定block的输入输出以及梯度变化情况。
第一个例子
本案例通过一个2层的线性网络对hook功能进行测试,分别对每层进行hook,并获得其输入输出值以及梯度。
具体实现的代码如下
import torch
import torch.nn as nn
# 定义hook类
class SaveValues():
def __init__(self, layer):
self.model = None
self.input = None
self.output = None
self.grad_input = None
self.grad_output = None
# 注册hook
self.forward_hook = layer.register_forward_hook(self.hook_fn_act)
self.backward_hook = layer.register_full_backward_hook(self.hook_fn_grad)
# 定义正向传播hook
def hook_fn_act(self, module, input, output):
self.model = module
self.input = input[0]
self.output = output
# 定义反向传播hook
def hook_fn_grad(self, module, grad_input, grad_output):
self.grad_input = grad_input[0]
self.grad_output = grad_output[0]
# 移除hook
def remove(self):
self.forward_hook.remove()
self.backward_hook.remove()
# 定义网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = nn.Linear(2, 5)
self.l2 = nn.Linear(5, 10)
def forward(self, x):
x = self.l1(x)
x = self.l2(x)
return x
# 初始化网络,损失函数,变量
l1loss = nn.L1Loss()
model = Net()
gt = torch.ones((10,), dtype=torch.float32, requires_grad=False)
x = torch.ones((2,), dtype=torch.float32, requires_grad=False)
# 实例化hook
value1 = SaveValues(model.l1)
value2 = SaveValues(model.l2)
# 训练过程
y = model(x)
loss = l1loss(y, gt)
loss.backward()
接下来,我们来调用hook所得的值。我们分查看了两个hook所得的值。
第二个例子
第二个例子我们将猫的照片放入已经训练好的resnet50,观察其产生的feature map和gradient map。
我们用sum聚合了所有通道的值,并进行可视化
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
img_path = "cat.png"
img = Image.open(img_path).convert('RGB')
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
gt = torch.tensor([0])
res50_weight_filename = 'resnet50-19c8e357.pth'
res50_model = models.resnet50(pretrained=False)
res50_model.load_state_dict(torch.load(res50_weight_filename))
loss_function = nn.CrossEntropyLoss()
#取了resnet第二个block(下图绿色部分)的最后一个Bottleneck和第三个block(下图蓝色部分)的最后一个Bottleneckvalue1 = SaveValues(res50_model.layer2[-1])
value2 = SaveValues(res50_model.layer3[-1])
y =res50_model(img)
loss = loss_function(y, gt)
loss.backward()
我们输入的猫图片如图所示:
然后我们将特征图可视化
plt.subplot(2,2,1)
plt.axis('off')
plt.title('input')
plt.imshow(value1.input.detach().numpy()[0,:,:,:].sum(axis=0))
plt.subplot(2,2,2)
plt.axis('off')
plt.title('output')
plt.imshow(value1.output.detach().numpy()[0,:,:,:].sum(axis=0))
plt.subplot(2,2,3)
plt.axis('off')
plt.title('grad_input')
plt.imshow(value1.grad_input.detach().numpy()[0,:,:,:].sum(axis=0))
plt.subplot(2,2,4)
plt.axis('off')
plt.title('grad_output')
plt.imshow(value1.grad_output.detach().numpy()[0,:,:,:].sum(axis=0))
第二个块的最后一个Bottleneck
第三个块的最后一个Bottleneck
参考资料:
https://blog.csdn.net/weixin_41978699/article/details/123102220
https://zhuanlan.zhihu.com/p/87853615
https://github.com/pytorch/pytorch/issues/61519
评论