如何优雅的提取feature map和gradient map(2个简单例子理解pytorch中的hook的使用)

AI算法与图像处理

共 5398字,需浏览 11分钟

 ·

2022-06-17 21:19

简介

我们在训练网络的过程中往往希望知道在forward过程中每层的输入和输出,也想知道backward过程中反传的梯度。

pytorch提供了两个钩子注册函数(register_forward_hook,register_full_backward_hook,用于获取forward和backward中输入和输出,可以在不改变网络的定义代码,不需要在forward函数中return某个感兴趣层的输出,可以获得特定层或者特定block的输入输出以及梯度变化情况。

第一个例子

本案例通过一个2层的线性网络对hook功能进行测试,分别对每层进行hook,并获得其输入输出值以及梯度。

具体实现的代码如下

import torch
import torch.nn as nn

# 定义hook类
class SaveValues():
    def __init__(self, layer):
        self.model  = None
        self.input  = None
        self.output = None
        self.grad_input  = None
        self.grad_output = None
        # 注册hook
        self.forward_hook  = layer.register_forward_hook(self.hook_fn_act)
        self.backward_hook = layer.register_full_backward_hook(self.hook_fn_grad)
    # 定义正向传播hook
    def hook_fn_act(self, module, input, output):
        self.model  = module
        self.input  = input[0]
        self.output = output
    # 定义反向传播hook
    def hook_fn_grad(self, module, grad_input, grad_output):
        self.grad_input  = grad_input[0]
        self.grad_output = grad_output[0]
    # 移除hook
    def remove(self):
        self.forward_hook.remove()
        self.backward_hook.remove()

# 定义网络        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = nn.Linear(25)
        self.l2 = nn.Linear(510)

    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x

# 初始化网络,损失函数,变量
l1loss = nn.L1Loss()
model  = Net()
gt = torch.ones((10,), dtype=torch.float32, requires_grad=False)
x  = torch.ones((2,), dtype=torch.float32, requires_grad=False)

# 实例化hook
value1  = SaveValues(model.l1)
value2  = SaveValues(model.l2)

# 训练过程
y = model(x)
loss  = l1loss(y, gt)
loss.backward()

接下来,我们来调用hook所得的值。我们分查看了两个hook所得的值。

第二个例子


第二个例子我们将猫的照片放入已经训练好的resnet50,观察其产生的feature map和gradient map。

我们用sum聚合了所有通道的值,并进行可视化

from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms


data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.4850.4560.406], [0.2290.2240.225])])

img_path = "cat.png"
img = Image.open(img_path).convert('RGB')
img = data_transform(img)
img = torch.unsqueeze(img, dim=0)
gt  = torch.tensor([0])


res50_weight_filename = 'resnet50-19c8e357.pth'
res50_model = models.resnet50(pretrained=False)
res50_model.load_state_dict(torch.load(res50_weight_filename))
loss_function = nn.CrossEntropyLoss()

#取了resnet第二个block(下图绿色部分)的最后一个Bottleneck和第三个block(下图蓝色部分)的最后一个Bottleneckvalue1  = SaveValues(res50_model.layer2[-1])
value2  = SaveValues(res50_model.layer3[-1])

y =res50_model(img)
loss  = loss_function(y, gt)
loss.backward()


我们输入的猫图片如图所示:


然后我们将特征图可视化


plt.subplot(2,2,1)
plt.axis('off')
plt.title('input')
plt.imshow(value1.input.detach().numpy()[0,:,:,:].sum(axis=0))
plt.subplot(2,2,2)
plt.axis('off')
plt.title('output')
plt.imshow(value1.output.detach().numpy()[0,:,:,:].sum(axis=0))
plt.subplot(2,2,3)
plt.axis('off')
plt.title('grad_input')
plt.imshow(value1.grad_input.detach().numpy()[0,:,:,:].sum(axis=0))
plt.subplot(2,2,4)
plt.axis('off')
plt.title('grad_output')
plt.imshow(value1.grad_output.detach().numpy()[0,:,:,:].sum(axis=0))

第二个块的最后一个Bottleneck

第三个块的最后一个Bottleneck



 

参考资料:

https://blog.csdn.net/weixin_41978699/article/details/123102220

https://zhuanlan.zhihu.com/p/87853615

https://github.com/pytorch/pytorch/issues/61519


浏览 166
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报