MLP-Mixer: 基于多层感知机体系结构的图像识别网络_rearrange('b n d -> b d n') pytorch-程序员宅基地

技术标签: 论文阅读  MLP-Mixer  人工智能  

整体导读

卷积神经网络(CNNs)是计算机视觉的主流模型,近年来,基于注意力的网络,如vision transformer也得到了广泛的应用。2021年3月4日,谷歌人工智能研究院Ilya Tolstikhin, Neil Houlsby等人研究员提出一种基于多层感知机结构的MLP-Mixer并在顶会“Computer Vision and Pattern Recognition(CVPR)”上发表一篇题为“MLP-Mixer: An all-MLP Architecture for Vision”的文章。MLP-Mixer包含两种类型的MLP层:一种是独立应用于图像patches的MLP(即“混合”每个位置特征),另一种是跨patches应用的MLP(即“混合”空间信息)。当在大数据集上训练时,或使用正则化训练方案时,MLP-Mixer在图像分类基准上获得有竞争力的分数,并且预训练和推理成本与最先进的模型相当。作者希望这些结果能激发出更深入的研究,超越成熟的CNN和transformer领域。

论文名称:MLP-Mixer: An all-MLP Architecture for Vision

原文链接:https://arxiv.org/pdf/2105.01601.pdf

会议名称:Computer Vision and Pattern Recognition

值得注意的是最近LeCun 的 twitter发表观点,其实MLP-Mixer在很多地方使用卷积,只是用一堆奇奇怪怪的词来描述自己在做的运算。

感兴趣的同学可以读一下https://mp.weixin.qq.com/s/cl5QCRgeY8zWc3AwE4waZQ

MLP模块代码

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

Mixer Layer代码

class MixerBlock(nn.Module):

    def __init__(self, dim, num_patch, token_dim, channel_dim, dropout=0.):
        super().__init__()

        self.token_mix = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n d -> b d n'),
            FeedForward(num_patch, token_dim, dropout),
            Rearrange('b d n -> b n d')
        )

        self.channel_mix = nn.Sequential(
            nn.LayerNorm(dim),
            FeedForward(dim, channel_dim, dropout),
        )

    def forward(self, x):
        x = x + self.token_mix(x)

        x = x + self.channel_mix(x)

        return x

整体模型代码

class MLPMixer(nn.Module):

    def __init__(self, in_channels, dim, num_classes, patch_size, image_size, depth, token_dim, channel_dim):
        super().__init__()

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patch = (image_size // patch_size) ** 2
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, dim, patch_size, patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )

        self.mixer_blocks = nn.ModuleList([])

        for _ in range(depth):
            self.mixer_blocks.append(MixerBlock(dim, self.num_patch, token_dim, channel_dim))

        self.layer_norm = nn.LayerNorm(dim)

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):

        x = self.to_patch_embedding(x)

        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)

        x = self.layer_norm(x)

        x = x.mean(dim=1)

        return self.mlp_head(x)

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_38784098/article/details/116571813

智能推荐

html Canvas粒子文字特效_html canvas 效果-程序员宅基地

文章浏览阅读757次,点赞19次,收藏9次。文字动态特效_html canvas 效果

el-table-column 表格列自适应宽度的组件封装说明

针对组件业务上的需求,需要给 el-table-column 加上限制,需保证表头在一行展示,部分列的内容要一行展示,自适应单项列的宽度;

Ali-Sentinel-链路控制

Ali-Sentinel-链路控制

C语言实现SM4(基于GMSSL)_使用c语言调用openssl实现sm4代码-程序员宅基地

文章浏览阅读4.2k次。环境:vs2019 gmssl 32位编译1、首先新建项目2、在VS的工程设置工程属性(参考连接https://blog.csdn.net/zhonghua_csdn/article/details/99011892)右击工程名 ——> 选择“属性” 在“VC++目录”——> “包含目录”中添加openSSL的include文件(在您安装openssl的文件下) 在“VC++目录”——> “库目录”中添加openSSL的lib文件(在您安装openssl的文件下) 在“._使用c语言调用openssl实现sm4代码

让Windows免疫Autorun病毒-程序员宅基地

文章浏览阅读73次。来源:http://www.bysjhf.com.cn目前,U盘病毒的情况非常严重,几乎所有带病毒的U盘,根目录里都有一个autorun.inf。右键菜单多了“自动播放”、“Open”、“Browser”等项目。由于我们习惯用双击来打开磁盘,但现在我们双击,通常不是打开U盘,而是让autorun.inf里所设的程序自动播放。所以对于很多人来说相当麻烦。其实Autorun...._linux怎么为windows做autorun免疫

随便推点

Qt报错:Error while building/deploying project *** (kit: Desktop Qt 5.12.9 MSVC2017...)_error while building/deploying project xianzhazhi -程序员宅基地

文章浏览阅读1.5k次。Qt Creator 报错:Error while building/deploying project helloworld (kit: Desktop Qt 5.6.2 MinGW 32bit) When executing step "qmake" - zhangjunwu - 博客园 (cnblogs.com)https://www.cnblogs.com/zhangjunwu/p/7417566.html注意:Qt文件路径不要出现中文名字和空格!!!......_error while building/deploying project xianzhazhi (kit: desktop qt 5.12.9 ms

解决create-react-app创建项目出错_installing packages. this might take a couple of m-程序员宅基地

文章浏览阅读1.3k次。Installing packages. This might take a couple of minutes.Installing react, react-dom, and react-scripts with cra-template-typescript...npm ERR! code 1npm ERR! path C:\Users\MHX\Desktop\react-demo\node_modules\canvasnpm ERR! command failednpm ERR! comm_installing packages. this might take a couple of minutes. installing react,

关于西电计科本科学习的一些经验分享与资料汇总_西电毕设拿良容易吗-程序员宅基地

文章浏览阅读1.9w次,点赞43次,收藏214次。关于西电计科本科学习的一些经验分享与资料汇总_西电毕设拿良容易吗

【nodejs】使用express-generator快速搭建项目框架-程序员宅基地

文章浏览阅读279次,点赞9次,收藏3次。项目根目录打开终端,执行以下命令,安装依赖。执行以下命令后,在浏览器中打开。就可以打开这个项目了。

c++二维vector_c++ 二维vector-程序员宅基地

文章浏览阅读8.5k次,点赞4次,收藏24次。关于C++中二维vector使用vector本来就是可以用来代替一维数组的,vector提供了operator[]函数,可以像数组一样的操作,而且还有边界检查,动态改变大小。这里只介绍用它来代替二维的数组,二维以上的可以依此类推。1、定义二维vectorvector<vector<int>> A;//错误的定义方式vector<vector<int> > A;//正缺的定义方式vector<vector<int> > v;/_c++ 二维vector

python算法题_python算法题-程序员宅基地

文章浏览阅读187次。广告关闭腾讯云11.11云上盛惠 ,精选热门产品助力上云,云服务器首年88元起,买的越多返的越多,最高返5000元!导言:记录下学习的算法题,写练多,脑子才能转的快! 今日算法题:二分法查找说下我对于二分法查找的理解:【和猜数字游戏差不多】 要在一个有序数列中找到一个与对应给定数字。 1、找到有序数列中最中间的数字2、若中间值大于给定值,则在左边数列重新二分查找3、若中间值小于给定值,则在右边数列..._python服务端算法题