[论文理解] CBAM: Convolutional Block Attention Module

CBAM: Convolutional Block Attention Module

简介

本文利用attention机制,使得针对网络有了更好的特征表示,这种结构通过支路学习到通道间关系的权重和像素间关系的权重,然后乘回到原特征图,使得特征图可以更好的表示。

Convolutional Block Attention Module

这里的结构有点类似与SENet里的支路结构。

对于Channel attention module,先将原feature map分别做global avg pooling 和global max pooling,然后将两pooling后的向量分别连接一个FC层,之后point-wise相加。激活。

这里用global pooling的作用是捕捉全局特征,因为得到的权重描述的是通道间的关系,所以必须要全局特征才能学习到这种关系。

之所以avg pooling和max pooling一起用,是因为作者发现max pooling能够捕捉特征差异,avg pooling能捕捉一般信息,两者一起用的效果要比单独用的实验结果要好,。

结构如图:

对于Spatial attention module,作者使用了1×1的pooling,与上面一样,使用的是1×1的avg pooling和1×1的max pooling,而没有用1×1卷积,两者concat,紧接着是一层7×7卷积,然后激活。最后输出就是1×h×w。

结构如图:

作者提到了两者的顺序,先做channel attention比先做spatial attention要好很多。

后面作者实验了spatial attention module里1×1conv、1×1pooling的效果,最后发现pooing的效果要比卷积的效果要好,因此上面的结构采用的是pooling而不是卷积结构。

后面就是一些结构了。

几句话简单复现了一下。

'''
@Descripttion: This is Aoru Xue's demo,which is only for reference
@version:
@Author: Aoru Xue
@Date: 2019-09-12 01:24:03
@LastEditors: Aoru Xue
@LastEditTime: 2019-09-12 02:24:25
'''
import torch
import torch.nn as nn

class ChannelAttentionModule(nn.Module):
    def __init__(self,size = 128,r = 2):
        super(ChannelAttentionModule, self).__init__()
        self.max_pooling = nn.MaxPool2d(size)
        self.avg_pooling = nn.AvgPool2d(size)
        self.fc1 = nn.Linear(64,64//r)
        self.fc2 = nn.Linear(64//r,64)
        self.relu = nn.ReLU(inplace=True)
    def forward(self,x):
        max_pool = self.max_pooling(x).view(2,64)
        max_pool = self.fc1(max_pool)
        avg_pool = self.avg_pooling(x).view(2,64)
        avg_pool = self.fc1(avg_pool)
        t = max_pool + avg_pool
        x = self.fc2(t).view(2,64,1,1)
        x = self.relu(x)
        return x
class SpatialAttentionModule(nn.Module):
    def __init__(self,):
        super(SpatialAttentionModule, self).__init__()
        self.conv7x7 = nn.Conv2d(2,64,kernel_size= 7 , stride=1,padding = 3)
        self.sigmoid = nn.Sigmoid()
    def forward(self,x):
        max_pool = torch.max(x,dim = 1)[0]
        avg_pool = torch.mean(x,dim = 1)
        x = self.conv7x7(torch.stack([max_pool,avg_pool],dim = 1))
        x = self.sigmoid(x)
        return x
class ResBlock(nn.Module):
    def __init__(self,):
        super(ResBlock, self).__init__()
        self.channel_module = ChannelAttentionModule(r = 2)
        self.spatial_module = SpatialAttentionModule()
    def forward(self,x):
        c = self.channel_module(x)
        x = c*x
        s = self.spatial_module(x)
        x = s * x
        return x
if __name__ == "__main__":
    x = torch.randn(2,64,128,128)
    net = ResBlock()
    print(net(x).size())

原文地址:https://www.cnblogs.com/aoru45/p/11509797.html

时间: 2024-10-31 06:47:02

[论文理解] CBAM: Convolutional Block Attention Module的相关文章

【CV中的Attention机制】易于集成的Convolutional Block Attention Module(CBAM模块)

前言: 这是CV中的Attention机制专栏的第一篇博客,并没有挑选实现起来最简单的SENet作为例子,而是使用了CBAM作为第一个讲解的模块,这是由于其使用的广泛性以及易于集成.目前cv领域借鉴了nlp领域的attention机制以后生产出了很多有用的基于attention机制的论文,attention机制也是在2019年论文中非常火.这篇cbam虽然是在2018年提出的,但是其影响力比较深远,在很多领域都用到了该模块,所以一起来看一下这个模块有什么独到之处,并学着实现它. 1. 什么是注意

[论文理解]Region-Based Convolutional Networks for Accurate Object Detection and Segmentation

Region-Based Convolutional Networks for Accurate Object Detection and Segmentation 概括 这是一篇2016年的目标检测的文章,也是一篇比较经典的目标检测的文章.作者介绍到,现在表现最好的方法非常的复杂,而本文的方法,简单又容易理解,并且不需要大量的训练集. 文章的大致脉络如图. 产生region proposal 文章提到了滑窗的方法,由于滑窗的方法缺点非常明显,就是每次只能检测一个aspect ratio,所以确

[论文理解] Making Convolutional Networks Shift-Invariant Again

Making Convolutional Networks Shift-Invariant Again Intro 本文提出解决CNN平移不变性丧失的方法,之前说了CNN中的downsample过程由于不满足采样定理,所以没法确保平移不变性.信号处理里面解决这样的问题是利用增大采样频率或者用抗混叠方法,前者在图像处理里面设置stride 1就可实现,但stride 1已经是极限,本文着重于后者,使用抗混叠使得CNN重新具有平移不变性. 混叠是在采样频率不满足采样定理时出现的一种现象,抗混叠通过抗

论文: Deformable Convolutional Networks

论文: Deformable Convolutional Networks CNN因为其内部的固定的网络结构,对模型几何变换的识别非常有限. 本paper给出了两个模块deformable convolution 和 deformable ROI-Pooling来提高CNN的模型变换能力. 过去的办法解决几何变换的方法,一,使用data Augmentation来增大不同几何形状的object,二,使用sift 或者 sliding windows这样的方法来解决. 本paper主要针对三个mo

[论文理解]关于ResNet的进一步理解

[论文理解]关于ResNet的理解 这两天回忆起resnet,感觉残差结构还是不怎么理解(可能当时理解了,时间长了忘了吧),重新梳理一下两点,关于resnet结构的思考. 要解决什么问题 论文的一大贡献就是,证明了即使是深度网络,也可以通过训练达到很好的效果,这跟以往的经验不同,以往由于网络层数的加深,会出现梯度消失的现象.这是因为,在梯度反传的时候,由于层数太深,传递过程又是乘法传递,所以梯度值会越乘越小,梯度消失在所难免.那么怎么才能解决这个问题呢?resnet提供了很好的思路. 怎么解决

论文笔记之:Deep Attention Recurrent Q-Network

Deep Attention Recurrent Q-Network 5vision groups  摘要:本文将 DQN 引入了 Attention 机制,使得学习更具有方向性和指导性.(前段时间做一个工作打算就这么干,谁想到,这么快就被这几个孩子给实现了,自愧不如啊( ⊙ o ⊙ ))   引言:我们知道 DQN 是将连续 4帧的视频信息输入到 CNN 当中,那么,这么做虽然取得了不错的效果,但是,仍然只是能记住这 4 帧的信息,之前的就会遗忘.所以就有研究者提出了 Deep Recurre

深入理解nodejs 中 exports与module.exports

在Javascript 中,有2种作用域,分为 全局作用域 ,和函数作用域, 在 浏览器端 , 全局作用域 就是 window对象的属性, 函数作用域 就是 ,某个 函数 生成的对象的属性: <!DOCTYPE html> <html> <head lang="en"> <meta charset="UTF-8"> <title></title> <script> var name

《Real-Time Compressive Tracking》论文理解

     这是Kaihua Zhang发表在ECCV2012的paper,paper的主题思想是利用满足压缩感知(compressive sensing)的RIP(restricted isometry property)条件的随机测量矩阵(random measurement matrix)对多尺度(multiple scale)的图像特征(features)进行降维,然后通过朴素贝叶斯分类器(naive Bayes classifier)对特征进行分类预测目标位置.   首先介绍下paper

[论文理解] CornerNet: Detecting Objects as Paired Keypoints

[论文理解] CornerNet: Detecting Objects as Paired Keypoints 简介 首先这是一篇anchor free的文章,看了之后觉得方法挺好的,预测左上角和右下角,这样不需要去管anchor了,理论上也就w*h个点,这总比好几万甚至好几十万的anchor容易吧.文章灵感来源于Newell et al. (2017) on Associative Embedding in the context of multi-person pose estimation