谷歌提出MaskGIT:掩码生成图像Transformer_Amusi(CVer)的博客-程序员秘密

技术标签: 算法  机器学习  计算机视觉  深度学习  人工智能  

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

转载自:机器之心

来自谷歌研究院的研究者提出了一种使用双向 transformer 解码器的新型图像合成模型 MaskGIT,在性能和速度上都获得了大幅改进。

生成式 transformer 在合成高保真和高分辨率图像方面得到了快速普及。但迄今为止最好的生成式 transformer 模型仍是将图像视为一系列 token,并按照光栅扫描顺序(即逐行)解码图像。然而这种策略既不是最优的,也不高效。

近日,来自谷歌研究院的研究者提出了一种使用双向 transformer 解码器的新型图像合成模型 MaskGIT。在训练期间,MaskGIT 通过关注各个方向的 token 来学习预测随机掩码 token。在推理阶段,模型首先同时生成图像的所有 token,然后以上一次生成为条件迭代地细化图像。实验表明,MaskGIT 在 ImageNet 数据集上显著优于 SOTA transformer 模型,并将自回归解码的速度提高了 64 倍。

4ce83e5ee0570fb9aec542b0bccbc711.png

MaskGIT: Masked Generative Image Transformer

论文地址:https://arxiv.org/abs/2202.04200

此外,该研究还表明 MaskGIT 可以轻松扩展到各种图像编辑任务,例如修复、外推和图像处理。

相关研究

先前的模型 VQVAE 提出分两个阶段在潜在空间中生成图像。

第一个阶段称为 tokenization,其中尝试将图像压缩到离散的潜在空间中,这一阶段主要包含三个部分:

一个编码器 E ,负责学习将图像 x∈9623bdcfd9ca065a1a91df006111c961.png tokenize 成潜在嵌入 E(x);

一个用于最近邻查找 codebook 7de8b48df21c55340c96b004d2455455.png,以将嵌入量化为视觉 token;

一个解码器 G,它根据视觉 token e 预测重建图像6188db97b86d749700ff8ccd2555f3f6.png

第二个阶段首先使用深度自回归模型预测视觉 token 的潜在先验,然后使用第一阶段的解码器将 token 序列映射到图像像素中。

这种两阶段范式是很有效的,因此几种常用的方法都遵循了这种范式,例如 DALL-E、VQGAN。其中,VQGAN 在第一阶段增加了对抗性损失和感知损失以提高图像保真度。

MaskGIT

上述使用两阶段范式的方法由于仍然采用自回归模型,因此第二阶段的解码时间与 token 序列长度成比例。而本研究的目标是设计一种利用并行解码和双向生成的新图像合成范式,遵循上述两阶段方案并改进第二阶段。第一阶段采用与 VQGAN 模型相同的设置,并将潜在的改进留给未来工作的 tokenization 步骤;对于第二阶段,研究者提出通过掩码视觉 token 建模(Masked Visual Token Modeling,MVTM 学习双向 transformer。

806c9b352c7e7e9a35cfa70a6e759aac.png

训练中的 MVTM

该研究用b71a572c201266d63278927c74af53f4.png表示将图像输入到 VQ 编码器获得的潜在 token,其中 N 是重构后的 token 矩阵的长度,575f7cefd74ec7bd6222f30e5289c17e.png 是对应的二进制掩码。在训练期间,该研究采样 token 的子集,并用一个特殊的 [MASK] token 替代它们。如果 m_i=1,就用 [MASK] 取代 token y_i;如果 m_i=0,y_i 保留。

采样过程由掩码调度函数(mask scheduling function)336f0b38cc20f5fa75a4cf2fe7e25f07.png 进行参数化,然后按照如下步骤:

首先从 0 到 1 采样一个比率,然后在 Y 中统一选择 d01e0e5dce3fb9f94f6d3c25ae8c519c.png 个 token 来放置掩码,其中 N 是长度。掩码调度显著影响了图像的生成质量。

迭代解码

在自回归解码中,token 是根据先前生成的输出顺序生成的。这个过程是不可并行的,而图像的 token 长度通常比语言长得多,因此速度非常慢。该研究提出了一种新型解码方法,其中图像中的所有 token 都是同时并行生成的,这基于 MTVM 的双向自注意力。

理论上讲,该模型能够推断出所有 token 并在单次传递中生成整个图像,但训练任务的不一致给该研究带来了挑战。为了在推理时生成图像,该研究从一个空白 canvas 开始,所有 token 都被掩码,即76c86790407984b743b4d6c75ceffbb3.png。该研究提出的迭代解码方法,每次迭代的算法运行步骤如下:

1. 预测

2. 采样

3. 掩码调度

4. 掩码

掩码设计

研究者发现图像的生成质量受到掩码设计的显著影响。该方法通过一个掩码调度函数dd03844385a271404924f62e01ee885f.png对掩码过程进行建模,该函数负责计算给定潜在 token 的掩码比率。在推理期间,函数b8a28ebc117df72cf6aaafb06779f336.png165eb9b7579370332d26d363a66089f0.png的输入代表解码的进度;在训练期间,该研究在 [0,1) 中随机采样一个比率 r 来模拟各种解码场景。

实验

该研究从质量、效率和灵活性方面对 MaskGIT 在图像生成方面进行了实验评估。

类条件图像合成

该研究在 ImageNet 256 X 256 和 ImageNet 512 X 512 上评估了 MaskGIT 模型在类条件(class-conditional)图像合成任务上的性能,主要结果如下表 1 所示。

98f15d78ef9f6e1e8f37fa3f6bb04295.png

质量。在 ImageNet 256 X 256 上,不使用任何特殊的采样策略,MaskGIT 在 FID 和 IS 方面都显著优于 VQGAN。

速度。该研究通过评估每个模型生成样本所需的步骤数(前向传递)来评估模型速度。如表 1 所示,在所有基于非 GAN 的模型中,MaskGIT 在两种分辨率上所需的步骤最少。

为了进一步证实 MaskGIT 和自回归模型之间的速度差异,该研究对 MaskGIT 和 VQGAN 的解码过程进行了运行时比较。如下图 4 所示,MaskGIT 将 VQGAN 显著加速了 30-64 倍,随着图像分辨率(以及输入 token 长度)的增加,加速变得更加明显。

412108180024a020968331a233337982.png

多样性。除了样本质量外,该研究还将分类准确率得分 (CAS) 和 Precision/Recall 作为评估样本多样性的两个指标。与 BigGAN 的样本相比,MaskGIT 的样本更加多样化,具有更多种光照、姿态、规模和语境,如下图 5 所示。

9a31065d669d0b69a9623d680a7b3dfd.png

图像编辑应用

该研究展示了 MaskGIT 在三个图像编辑任务上的直接应用:类条件图像编辑、图像修复和图像扩展(outpainting)。如果将任务看作对初始二进制掩码 M MaskGIT 在其迭代解码中使用约束,那么这三个任务几乎都可以轻松地转换为 MaskGIT 可以处理的任务。

该研究表明,无需修改架构或任何特定于任务的训练,MaskGIT 就能够在所有三个应用程序上产生非常优秀的结果。此外,MaskGIT 在图像修复和扩展方面获得了与专用模型相当的性能。

在类条件图像编辑任务上,该研究定义了一个新的类条件图像编辑任务来展示 MaskGIT 的灵活性。模型在给定类的边界框内重新生成特定内容,同时保留语境,即框外的内容。由于违背了预测顺序,因此自回归方法是不可行的。

然而,对于 MaskGIT,如果将边界框区域视为迭代解码算法的初始掩码的输入,这个问题就迎刃而解了。下图 6 给出了一些示例结果。

f20ebe0a3e338a6824d5ef086455db6b.png

表 2 比较了几种方法的定量结果。MaskGIT 在 FID 和 IS 中均以显著优势击败 DeepFill 和 HiFill,同时获得接近 SOTA 修复方法 CoModGAN 的分数。

741e12a7e07c9033dac82b8db475dd23.png

如下图 7 所示,MaskGIT 还能够在给定相同输入和不同种子的情况下合成不同的结果。

568e8868326dd7c5535d3884776f062b.png

消融实验

为了验证新设计的效用,该研究在 ImageNet 256×256 的默认设置上进行了消融实验。MaskGIT 的一个关键设计是用于训练和迭代解码的掩码调度函数,实验结果如下表 3 和图 8 所示。

c36622a84490461fbb336ca357bfd3e8.png

值得注意的是,如图 8 所示,在相同的设置下,更多的迭代不一定更好:随着迭代次数 T 的增加,除了对数函数在整个过程中都表现不佳以外,其他所有函数都达到了一个「sweet spot」位置,即模型的性能在再次恶化之前达到峰值。

ICCV和CVPR 2021论文和代码下载

后台回复:CVPR2021,即可下载CVPR 2021论文和代码开源的论文合集

后台回复:ICCV2021,即可下载ICCV 2021论文和代码开源的论文合集

后台回复:Transformer综述,即可下载最新的3篇Transformer综述PDF
CVer-Transformer交流群成立
扫描下方二维码,或者添加微信:CVer6666,即可添加CVer小助手微信,便可申请加入CVer-Transformer 微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer等。
一定要备注:研究方向+地点+学校/公司+昵称(如Transformer+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群

▲扫码或加微信: CVer6666,进交流群
CVer学术交流群(知识星球)来了!想要了解最新最快最好的CV/DL/ML论文速递、优质开源项目、学习教程和实战训练等资料,欢迎扫描下方二维码,加入CVer学术交流群,已汇集数千人!

▲扫码进群
▲点击上方卡片,关注CVer公众号

整理不易,请点赞和在看
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/amusi1994/article/details/123650528

智能推荐

golang并发模式之生产者消费者模型(来自《go语言高级编程》)_qq_33332829的博客-程序员秘密

golang实现生产者消费者模型其实很简单:package mainimport ( "fmt" "time")//生产者:生成factor证书的序列func Producer(factor int, out chan<- int) { for i := 0; ; i++ { out <- i * factor }}//消费者func Consumer...

数据库导出数据字典_weixin_34111790的博客-程序员秘密

2019独角兽企业重金招聘Python工程师标准>>> ...

POJ3991 HDU3351 UVALive4733 Seinfeld【水题】_hdu3351stl_海岛Blog的博客-程序员秘密

SeinfeldTime Limit: 1000MS Memory Limit: 65536KTotal Submissions: 1551 Accepted: 722DescriptionI’m out of stories. For years I’ve been writing stories, some rather silly, just to make simple pro...

戴尔服务器改win7系统,戴尔电脑怎么把Win10系统改装win7系统?_大福�mkq的博客-程序员秘密

戴尔电脑怎么把Win10系统改装win7系统?现在几乎所有出厂的电脑,预装的系统都是Win10系统,而且如果贸然装win7的话根本装不上,而今天,小编就带着大家学习一下怎么使用戴尔电脑改装win7系统,一起看看吧!操作方法:1、重启或开机,也就是在出现戴尔Logo的时候,连续按F2进入Bios,切换到Secure Boot页面,回车选择Secure Boot Enable改成Disabled,关闭...

ROS下物体检测和识别功能(ROS下find_object_2d)_pd很不专业的博客-程序员秘密

效果:安装步骤:# ROS Kinetic:$ sudo apt-get install ros-kinetic-find-object-2d$ cd ~/catkin_ws$ git clone https://github.com/introlab/find-object.git src/find_object_2d$ catkin_make启动:$ roslaunch...

Xmega asf 编程_xmega 编程方式_phnbs的博客-程序员秘密

Table ofContents目录 1.      AtmelAVR Studio6.1介绍.21.1.       新项目启动...21.2.       C/C++可执行项目...21.3.       选AtmelBoard模板。...21.4.       选User-Boards模板...2

随便推点

js人脸识别,tracker.js前端人脸识别框架_js tracker_ntotl的博客-程序员秘密

<!doctype html><html><head> <meta charset="utf-8"> <title>基于tracking的取人脸</title> <script src="js/tracking-min.js"></sc

opencv之坑(四)——拟合圆_opencv 拟合圆_光电的一只菜鸡的博客-程序员秘密

建议用最小二乘法拟合圆,下面是参考链接:http://blog.sina.com.cn/s/blog_b27f71160101gxun.html http://www.cnblogs.com/dotLive/archive/2007/04/06/524633.htmlhttp://blog.csdn.net/andylao62/article/details/24522365ht...

把字符串转化成整数(python)_把字符串转化为整数python_Und.的博客-程序员秘密

一,题目描述将一个字符串转换成一个整数,要求不能使用字符串转换整数的库函数。 数值为0或者字符串不是一个合法的数值则返回0输入描述:输入一个字符串,包括数字字母符号,可以为空输出描述:如果是合法的数值表达则返回该数字,否则返回0示例1输入+21474836471a33输出21474836470二,代码...

Apache Ignite(三):核心特性之大数据处理、客户端和部署_weixin_33755649的博客-程序员秘密

本本文是Ignite系列的第三篇介绍性文章,内容整体比较简略,和第二篇文章一起,大体上介绍了Ignite平台的所有关键技术点,方便大家有一个整体的认识,供技术选型时参考。 \一、Spark共享RDD\Apache Ignite提供了一个Spark RDD抽象的实现,他允许跨越多个Spark作业时方便地在内存内共享状态,不管是在同一个应用内部还是在不同的Spark应用之间。 \IgniteRDD作为...

springboot整合dubbo遇到的坑_夜舞岚的博客-程序员秘密

1.项目启动,报错Exception in thread "main" java.lang.NoSuchMethodError: org.springframework.core.annotation.AnnotationAwareOrderComparator.sort(Ljava/util/List;)V————————————————版权声明:本文为CSDN博主「漠北空城」的原创文...

微软专家推荐11个Chrome 插件_aasd6283356的博客-程序员秘密

Web开发人员,需要长时间使用浏览器,尽管Windows10 Edge浏览器启动非常快速,且支持110多种设备,Edge支持基于JS 扩展,但也删除了很多旧功能像Active-X等插件。多数情况下,插件不仅可以解决一些安全问题,而且能够有效的解决浏览器负载问题。会在Chrome中每新打开一个标签页,插件都会自行运行,生成新的插件实例。这就意味着如果你打开10个标签页,并且浏览器有10个插...

推荐文章

热门文章

相关标签