超越CBAM,全新注意力机制!GAM:不计成本提高精度(附Pytorch实现)

共 3810字,需浏览 8分钟

 ·

2021-12-18 21:31

↑ 点击蓝字 关注极市平台

作者丨ChaucerG
来源丨集智书童
编辑丨极市平台

极市导读

 

为了提高计算机视觉任务的性能,人们研究了各种注意力机制。然而,以往的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,本文提出了一种通过减少信息弥散和放大全局交互表示来提高深度神经网络性能的全局注意力机制。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

论文链接:https://paperswithcode.com/paper/global-attention-mechanism-retain-information

为了提高计算机视觉任务的性能,人们研究了各种注意力机制。然而,以往的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,本文提出了一种通过减少信息弥散和放大全局交互表示来提高深度神经网络性能的全局注意力机制。


本文引入了3D-permutation 与多层感知器的通道注意力和卷积空间注意力子模块。在CIFAR-100和ImageNet-1K上对所提出的图像分类机制的评估表明,本文的方法稳定地优于最近的几个注意力机制,包括ResNet和轻量级的MobileNet。

1 简介

卷积神经网络已广泛应用于计算机视觉领域的许多任务和应用中。研究人员发现,CNN在提取深度视觉表征方面表现良好。随着CNN相关技术的改进,ImageNet数据集的图像分类准确率在过去9年里从63%提高到了90%。这一成就也归功于ImageNet数据集的复杂性,这为相关研究提供了难得的机会。由于它覆盖的真实场景的多样性和规模,有利于传统的图像分类、表征学习、迁移学习等研究。特别是,它也给注意力机制带来了挑战。

近年来,注意力机制在多个应用中不断提高性能,引起了研究兴趣。Wang等人使用编码-解码器residual attention模块对特征图进行细化,以获得更好的性能。Hu 等人分别使用空间注意力机制和通道注意力机制,获得了更高的准确率。然而,由于信息减少和维度分离,这些机制利用了有限的感受野的视觉表征。在这个过程中,它们失去了全局空间通道的相互作用。

本文的研究目标是跨越空间通道维度研究注意力机制。提出了一种“全局”注意力机制,它保留信息以放大“全局”跨维度的交互作用。因此,将所提出的方法命名为全局注意力机制(GAM)。

2 相关工作

注意力机制在图像分类任务中的性能改进已经有很多研究。

SENet在抑制不重要的像素时,也带来了效率较低的问题。

CBAM依次进行通道和空间注意力操作,而BAM并行进行。但它们都忽略了通道与空间的相互作用,从而丢失了跨维信息。

考虑到跨维度交互的重要性,TAM通过利用每一对三维通道、空间宽度和空间高度之间的注意力权重来提高效率。然而,注意力操作每次仍然应用于两个维度,而不是全部三个维度。

为了放大跨维度的交互作用,本文提出了一种能够在所有三个维度上捕捉重要特征的注意力机制。

3 GAM注意力机制

本文的目标是设计一种注意力机制能够在减少信息弥散的情况下也能放大全局维交互特 征。作者采用序贯的通道-空间注意力机制并重新设计了CBAM子模块。整个过程如图1 所示, 并在公式1和2。给定输入特征映射 , 中间状态 和输出 定义为:

其中 分别为通道注意力图和空间注意力图; 表示按元素进行乘法操作。

通道注意力子模块

通道注意子模块使用三维排列来在三个维度上保留信息。然后,它用一个两层的MLP(多层感知器)放大跨维通道-空间依赖性。(MLP是一种编码-解码器结构,与BAM相同,其压缩比为r);通道注意子模块如图2所示:

空间注意力子模块

在空间注意力子模块中,为了关注空间信息,使用两个卷积层进行空间信息融合。还从通道注意力子模块中使用了与BAM相同的缩减比r。与此同时,由于最大池化操作减少了信息的使用,产生了消极的影响。这里删除了池化操作以进一步保留特性映射。因此,空间注意力模块有时会显著增加参数的数量。为了防止参数显著增加,在ResNet50中采用带Channel Shuffle的Group卷积。无Group卷积的空间注意力子模块如图3所示:

Pytorch实现GAM注意力机制

import torch.nn as nn  
import torch  


class GAM_Attention(nn.Module):  
    def __init__(self, in_channels, out_channels, rate=4):  
        super(GAM_Attention, self).__init__()  

        self.channel_attention = nn.Sequential(  
            nn.Linear(in_channels, int(in_channels / rate)),  
            nn.ReLU(inplace=True),  
            nn.Linear(int(in_channels / rate), in_channels)  
        )  
      
        self.spatial_attention = nn.Sequential(  
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),  
            nn.BatchNorm2d(int(in_channels / rate)),  
            nn.ReLU(inplace=True),  
            nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),  
            nn.BatchNorm2d(out_channels)  
        )  
      
    def forward(self, x):  
        b, c, h, w = x.shape  
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)  
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)  
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)  
      
        x = x * x_channel_att  
      
        x_spatial_att = self.spatial_attention(x).sigmoid()  
        out = x * x_spatial_att  
      
        return out  

  

if __name__ == '__main__':  
    x = torch.randn(1, 64, 32, 48)  
    b, c, h, w = x.shape  
    net = GAM_Attention(in_channels=c, out_channels=c)  
    y = net(x)  

4实验

4.1 CIFAR-100

4.2 ImageNet-1K

4.3 消融实验

参考

[1].Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions

如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛:珠港澳人工智能算法大赛保姆级零基础人工智能教程
算法trick目标检测比赛中的tricks集锦从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述:一文弄懂各种loss function工业图像异常检测最新研究总结(2019-2020)


CV技术社群邀请函 #

△长按添加极市小助手
添加极市小助手微信(ID : cvmart4)

备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳)


即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群


每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~



觉得有用麻烦给个在看啦~  
浏览 269
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报