Amazon深度学习工程师总结的分类模型炼丹技巧总结_cosine learning rate decay-程序员宅基地

技术标签: 图像分类  深度学习  

论文名称:Bag of Tricks for Image Classification with Convolutional Neural Networks

论文链接:https://arxiv.org/pdf/1812.01187.pdf

很多时候,外界同学管深度学习算法工程师叫做调参工程师,简单直接的概括了深度学习工程师的工作,搞深度学习的同学自己也经常自嘲,称自己的工作是炼丹,的确,深度学习模型有时候确实很奇妙,而调参在一个模型的优化中起着至关重要的作用,正因为如此,也有越来越多的研究放在了调参这件事上,比如:学习率的优化算法,模型初始化算法等等。

其实,拿一个别人已经训练好的模型(比如ImageNet上预训练的ResNet),直接在自己的数据集上进行finetune,不需要怎么调参,一般都会得到不错的效果,这就是站在巨人的肩膀上,但是如果想继续提高模型的精度,该怎么做?继续调参?还是有一些其他的方法可以采用?

本篇文章就介绍了Amazon工程师总结的分类模型的调参技巧。

证明你的方法有效的最直接方法就是跟其他方法的效果最做对比,要对比当然就需要有一个baseline,这里我们就利用最常用的深度学习模型训练方法先训练一个base model,如下:

  • 训练数据预处理:
  1. 随机旋转一个batch的图像,然后将其编码成32位浮点数[0-255]。
  2. 随机截取一个长宽比在[3/4,4/3]的矩形,矩形面积占图像面积的[0.8,1],截取后将图像resize到224*224。
  3. 按照0.5的比例进行水平翻转。
  4. 对亮度,色度,饱和度进行跳帧。
  5. 增加PCA噪声,噪声分布为正态分布(0,0.1)。
  6. 对图像像素,减均值[123.68,116.779,103.939],除标准差[58.393,57.12,57.375]。
  • 验证数据:
  1. 短边缩放到256
  2. 中间截取224*224
  3. 对图像像素,减均值[123.68,116.779,103.939],除标准差[58.393,57.12,57.375]
  • 参数初始化
  1. 卷积层以及全连接层采用Xavier算法进行初始化。
  2. bn层, γ \gamma γ = 1 , β \beta β = 0
  • 参数优化方法
  1. 梯度采用带动量的梯度优化方法:Nesterov Accelerated Gradient
  2. 学习率:初始学习率0.1,每30个epoch学习率下降为原来的10%
  3. batchsize:256

上面是最常用的深度学习模型的参数设置,我们将其作为我们的模型的baseline,基础有了,下面我们来谈谈如何提升模型效果:

Question 1 : batch size是不是越大越好?

增加batch size可以增加网络的并行度,降低通信消耗,但是使用大的batch size同样也会带来一定的问题,比如:凸优化问题,随着batch size的增加,会增加数据的收敛的难度,换句话说,相同的epoch,使用大的batch size相比较使用小的batch size,小的batch size可能精度会更高一点。

那么我们是不是不该使用大的batch size,当然不是,下面介绍几种方法:

方法一:线性改变学习率

在梯度下降中,由于选择的sample是随机的,所以其梯度下降的方向也是随机的,当提高batch size之后,并不能改变这种随机性,但是由于图像数量的增加,却可以中和掉一部分的噪声,所以这个时候,我们可以增加一部分的学习率,使得学习的步子迈的大一点,比如,在resnet50中,batch size=256,我们选择了lr=0.1,当batchsize增大到b的时候,lr可以调整为0.1*b/256.

方法二:学习率预热

当我们开始训练模型的时候,往往模型的参数都是随机初始化的,并不能代表什么,所以如果此时选择一个较大的学习率,往往会导致模型的不稳定,那么什么是学习率预热,简单来说就是先使用一个较小的学习率,先迭代几个epoch,等到模型基本稳定的时候再用初始设置的学习率进行训练。举个例子,比如预热5个epoch,学习率设置成lr,则前5个epoch可以设置学习率线性递增,即第一个epoch:0.2*lr,第二个:0.4*epoch,依次类推,到第五个变为lr。

方法三:部分BN层 γ \gamma γ设置成0

resnet我们都知道,中间有很多BN层,BN层的提出可以说是模型训练的一个里程碑,它使得模型的训练更加简单,模型收敛更加快速,并且可以使用更大的学习率进行训练,BN层的作用就是对数据进行归一化操作,然后通过设置两个学习参数对归一化进行调整,即:

其中 γ \gamma γ以及 β \beta β是可以学习的, x ^ \widehat{x} x 是对输入进行标准化的结果,通常的做法是将 γ \gamma γ初始化为1, β \beta β初始化为0,这里作者建议在使用resnet的时候,将每个block的最后的BN层的 γ \gamma γ初始化为0,这样无论前面的结果如何,经过这一层都被清零了,block的输出就只有前面的short cut部分,导致每个block的输入都是一样的,作者解释说,这样可以使得网络的训练更加的方便。

方法四:不使用bias decay

在深度学习训练中,decay是一个很好的策略,可以防止参数多大引起过拟合,一般常采用的策略是L2范数,这里作者建议在做decay的时候,只对卷基层以及全连接层的weight加入decay就可以了,不需要对bias进行处理。

方法五:低精度训练

目前的GPU训练基本是采用32位浮点数进行数据存储的,即FP32,但是新的GPU比如Nvidia V100,支持16为浮点数的运行,速度可以提升2-3倍,而且采用更大的batchsize后(当然使用了前面的各种策略来配合batchsize的提升),精度还小有提升。如下面Table3所示。

当然,哪步更有用作者同样给了分析,如下表Table 4,可见,Zero γ \gamma γ带来的精度提升是最大的。

--------------上面讲述了参数的优化,下面来看看网络的优化方法------------

对于resnet结构,我们应该都不陌生,如果不了解可以查看博客:ResNet(Deep Residual Learning for Image Recognition)
在网络设计中,设计模型结构是最难的,下面介绍几个比较成熟的技巧:

  • ResNet B:

由于在ResNet中,每个stage的第一个block会进行图像尺度的缩小,采用了如图Figure1 深蓝色那部分结构,首先是1*1卷积层,stride=2,使得网络的feature map缩小一半,然后再经过3*3的卷积层,但是stride=2的1*1的卷积层会带来一个问题,会损失掉一半的信息(这里论文中提到是3/4),ResNet B修改了这个结构,将stride=2放到了3*3的卷积,这样就不会带来信息的损失了。

  • ResNet C:

使用3个3*3的卷积来替换resnet的7*7的卷积。并且前2个卷积的stride=3,channel=32,最后一个channel=64,这是一个比较老的套路了,在inception网络中被提出。

  • ResNet D:

既然主通道可以通过修改stride来降低信息损失,那么short cut为什么不可以呢?当然可以,作者是怎么做的呢?首先增加了一个2*2的average pooling layer, 设置stride=2,conv层的stride变为1。

对比一下这几个结构的精度,resnet B的提升是最有效的,相对提升了约0.5个百分点resnet C以及resnet D分别又提升了0.2%以及0.3%,整体大约提升了1个百分点。

--------------上面讲述了网络的优化,下面来看看模型的训练技巧--------------

  • 技巧一 cosine learning rate decay:

前面我们提到了训练模型的时候学习率需要warmup,可是在warmup之后,随着epoch的增加,学习率需要适度的调低,这就叫learning rate decay,我们常采用的方法是使用step decay,最简单的比如每20个epoch降低学习率为原来的10%,本篇综述提到了使用cosine learning rate decay,即采用cosine的方式来降低学习率,公式如下:

红色的线为采用cosine decay策略,蓝色的为采用step decay策略,可以发现,cosine decay策略更加的平滑,训练的精度提升也是逐步提升,不像step策略,会有跳跃。不过训到最后,精度基本差不多,个人觉得,step其实也挺好用。

  • 技巧二 Label Smoothing

在分类算法中,我们常采用的是one-hot编码,label smoothing的策略就是在one-hot的基础上,减去一个较小的值,如下公式,作者解释到,这样可以一定程度上减少过拟合,在采用one-hot编码的时候,只需要计算label类别的损失就可以了,采用label smoothing后,不仅仅需要计算label类别的损失,还需要增加其他类别的损失,这样,在one-hot编码的时候,对应目标的输出的目标值是正无穷,这样跟其他类别的差距更大,而增加了smooth label之后,由于引入了参数 σ \sigma σ,所以随着 σ \sigma σ的变化,其目标发生了变化,作者也画出了其目标图,如图figure 4(a),可见Gap基本集中在9左右,实际的实验结果如图Figure 4(b),也证实了这一点,b也符合a中基本都集中在9左右,并且b中smooth明显要比one-hot要小一点。

  • 技巧三:知识蒸馏

知识蒸馏也是提升模型精度的一个方法,知识蒸馏中,一般有一个精度较好的model作为teacher model,利用teacher model去帮助student model训练,比如:可以采用resnet-152作为teacher model,resnet-50作为student model。在利用知识蒸馏的方法进行训练的时候,需要增加用于蒸馏的loss,举个例子,假设p是真实概率,z和r分别是student以及teacher model的全连接层的输出结果,则损失函数为:

这里解释一下这个T(蒸馏温度参数),T是一个使得softmax output更加平滑的参数,以便于student model从teacher model学习参数。

下图是设置不同的T(蒸馏温度)得到的值,可以看到随着T的增大,曲线便的越来越平滑,其实设置这个标签的目的就是软化标签,增加训练难度,这样在inference的时候,将T重新设置为1,有难度的时候都可以表现很好,简单模式下,这样其分类的准确性就会更高了。

  • 技巧四:mixup training

所谓的mixup training,就是每次要取出2张图像,然后将两张图像进行线性组合,得到新的图像,以此来作为新的训练样本,进行网络的训练,如下公式,其中x代表图像数据,y代表标签,则得到的 x ^ \widehat{x} x y ^ \widehat{y} y 则为送入网络的训练样本。

mixup方法主要增强了训练样本之间的线性表达,增强网络的泛化能力,并且使用mixup方法需要较长的时间收敛。

实验结果

Table 6是作者使用上面不同的方法进行的实验结果,其中w/代表的是with,w/o代表的是without,根据实验结果可以发现,在使用了cosine decay,label smooth方法,在ImageNet的结果上,基本会提高一个点左右。采用mixup方法,三个网络也基本一致的提升了,对于蒸馏的效果,resnet网络的效果提升了大约0.3%,但是对于Inception-V3以及MobileNet,精度都下降了,为什么会出现这种情况呢?原因可能是:由于这里是利用ResNet-152作为teacher model的,而ResNet-152的输出的数据分布和Inception以及mobileNet的分布不同,所以导致了结果的不一致性。

Table 7是作者在Places 365数据集的测试结果,结果表明,采用这几个策略进行训练的结果同样也是有效的。

既然我们的模型在分类任务上表现提升了,那么使用此模型,在目标检测以及目标分割上是否有用呢?

首先作者测试了其在Faster R-CNN中的效果,测试结果如下表所示,这里,作者使用的是VGG-19作为backbone,使用不同精度的预训练模型进行训练,可以发现,在使用了精度更高的预训练模型之后,Faster R-CNN的mAP最终提高了大约4%(77.54->81.33)

在图像分割方面,最具代表性的网络就是FCN,作者在FCN网络上测试了不同精度的backbone对FCN的影响,结果如下表所示,可见,采用了作者调优后的模型还是具有一定的效果的,不过这里对于采用了cosine优化的效果最佳,对于采用了label smoothing,mixup等方法效果不是特别的明显,这是为什么呢?猜想应该是图像分割是对每个像素进行分类,而采用了诸如label smoothing方法,本身对于像素的标签产生了一定的影响,采用mixup等方法直接对像素值进行了改变,进而影响了对像素分类的效果

总结

本篇文章是Amazon对于分类模型炼丹方法的一个总结,介绍了很多trick,还是有很多借鉴意义的,至少我现在使用的其提供的resnet模型,在分类效果上确实有一定的提升。

---------------------------------彩弹---------------------------------

对于Amazon这样的大公司,当然也不会只是纸上谈兵,既然讲了这么多的方法,有没有预训练模型提供给我们使用呢?当然是有的,不过由于亚马逊推的是自己的深度学习框架MXNet+Gluon,所以这些预训练模型是在最新的GluonCV的model_zoo中提供。

有兴趣的读者可以自行查看,网站链接如下:

https://gluon-cv.mxnet.io/model_zoo/classification.html

并且,gluoncv中还给出了各个模型的运行时间对比,内存消耗对比等,如下图所示,方便大家根据自己的需求选择合适的模型。

最后附上一段使用gluoncv进行imagenet分类的代码:


import mxnet as mx
import gluoncv

# you can change it to your image filename
filename = 'classification-demo.png'
# you may modify it to switch to another model. The name is case-insensitive
model_name = 'ResNet50_v1d'
# download and load the pre-trained model
net = gluoncv.model_zoo.get_model(model_name, pretrained=True)
# load image
img = mx.image.imread(filename)
# apply default data preprocessing
transformed_img = gluoncv.data.transforms.presets.imagenet.transform_eval(img)
# run forward pass to obtain the predicted score for each class
pred = net(transformed_img)
# map predicted values to probability by softmax
prob = mx.nd.softmax(pred)[0].asnumpy()
# find the 5 class indices with the highest score
ind = mx.nd.topk(pred, k=5)[0].astype('int').asnumpy().tolist()
# print the class name and predicted probability
print('The input picture is classified to be')
for i in range(5):
    print('- [%s], with probability %.3f.'%(net.classes[ind[i]], prob[ind[i]]))
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/chunfengyanyulove/article/details/86670263

智能推荐

ejabberd分析(五)+订阅/添加好友_ejabberd iq-程序员宅基地

文章浏览阅读319次。模块ejabberd_c2s中,状态为session_established2。用户发送iq set 消息到服务器Friends服务器端匹配到[plain] view plaincopyprint?case Name of ...... To = xml:get_attr_s("to", Attrs_ejabberd iq

Python学习笔记02----M*N的棋盘,马从坐下到右上的行走方式_下过象棋的人都知道,马只能走'日'字形(包括旋转90°的日),现在想象一下,给你一个n-程序员宅基地

文章浏览阅读2.4k次。题目:下过象棋的人都知道,马只能走'日'字形(包括旋转90°的日),现在想象一下,给你一个n行m列网格棋盘,棋盘的左下角有一匹马,请你计算至少需要几步可以将它移动到棋盘的右上角,若无法走到,则输出-1.如n=1,m=2,则至少需要1步;若n=1,m=3,则输出-1。#寻找下一步左右可能的点def goNextStep(currentP,n,m): result=[] x=curr..._下过象棋的人都知道,马只能走'日'字形(包括旋转90°的日),现在想象一下,给你一个n

视频集中存储/云存储平台EasyCVR国标GB28181协议接入的报文交互数据包分析_视频磁盘阵列有哪些协议-程序员宅基地

文章浏览阅读1.1k次。设备端才能给服务器传递SIP ID、通道ID以及接入密码等信息。_视频磁盘阵列有哪些协议

部署高可用kubernetes_max-request-bytes-程序员宅基地

文章浏览阅读469次。kubernetes的基本概念写在前面的话整个安装过程中尽量不要出现写死的IP的情况出现,尽量全部使用域名代替IP。环境是ubuntu18.04kubernetes 高可用架构图ETCD高可用API-Server 高可用节点清单制作一个base镜像制作一个base镜像安装和修改通用组件,方便以后的节点部署。修改node的hosts文件如果你使用自己的域名我建议将如下配置配到你的域名管理中。注意:如你使用我这里的域名你需要将此信息写入到机器中的每一台node中(包括master和_max-request-bytes

Java8 Stream:2万字20个实例,玩转集合的筛选、归约、分组、聚合_java8 stream:2万字20个实例-程序员宅基地

文章浏览阅读10w+次,点赞4k次,收藏1.2w次。Java8 Stream横空出世,让我们从繁琐冗长的迭代中解脱出来,集合数据操作变得优雅简洁。这些操作:集合的filter(筛选)、归约(reduce)、映射(map)、收集(collect)、统计(max、min、avg)等等,一行代码即可搞定!让我们一起敲打案例代码,搞定Java8 stream吧!_java8 stream:2万字20个实例

View的事件分发机制(ViewGroup篇)_viewgroup motionevent-程序员宅基地

文章浏览阅读1.1k次。/** * {@inheritDoc} */ @Override public boolean dispatchTouchEvent(MotionEvent ev) { if (mInputEventConsistencyVerifier != null) { mInputEventConsistencyVerifier._viewgroup motionevent

随便推点

python学习[4]: 用python celery + rabbitMQ搭建并行分布式框架及验证_python celery rabbitmq 分布式-程序员宅基地

文章浏览阅读3.3k次。任务解耦(分布式并发处理):假设生产者和消费者分别是两个类。如果让生产者直接调用消费者的某个方法,那么生产者对于消费者就会产生依赖(也就是耦合)。将来如果消费者的代码发生变化,可能会影响到生产者。而如果两者都依赖于某个缓冲区,两者之间不直接依赖,耦合也就相应降低了。生产者直接调用消费者的某个方法,还有另一个弊端:由于函数调用是同步的(或者叫阻塞的),在消费者的方法没有返回之前,生产者只好一直等在那_python celery rabbitmq 分布式

【托福写作】TPO 1_tpo1综合写作-程序员宅基地

文章浏览阅读373次。综合写作[第一次写作]Both in the reading passage and the lecture, there are 3 aspects, which is for the company, for the society, and for the employee himself. And the opinions are totally the opposite.First of all, the four-day week is seen as a costly policy in_tpo1综合写作

[转]The specified module could not be found. (Exception from HRESULT: 0x8007007E)-程序员宅基地

文章浏览阅读2.2k次。问题:I have a managed C++ project (MyLib) that is utilizing 3rd party C++ code and libraries. When I have set a reference to that project (MyLib) I have seen this error. When I put the 3rd party_the specified module could not be found. (exception from hresult: 0x8007007e

linux ln 命令使用参数详解(ln -s 软链接)_软链接ln -s-程序员宅基地

文章浏览阅读1.1w次。source: http://www.jb51.net/LINUXjishu/150570.html作者:佚名 字体:[增加 减小] 来源:互联网 时间:04-04 23:52:55 我要评论这是linux中一个非常重要命令,请大家一定要熟悉。它的功能是为某一个文件在另外一个位置建立一个同不的链接,这个命令最常用的参数是-s,具体用法是:ln -s 源文件 目标文件这是linu_软链接ln -s

[AS3.0]一步一步学ActionScript 3.0(九) -程序员宅基地

文章浏览阅读1.7k次。前两节中,我们讲到了侦听,类与类之前也是可以侦听的,类与类之前的侦听就达到了类和类之前发消息的功能,这其实就是AS3.0中的消息机制。 我们先一个叫做MyClass的类:package net.smilecn{ import flash.display.Sprite; import flash.events.Event; import flash.events

经典的数据库访问接口-程序员宅基地

文章浏览阅读299次。package org.lyq.dao;import java.io.IOException;import java.io.InputStream;import java.sql.CallableStatement;import java.sql.Connection;import java.sql.DriverManager;import java.sql.PreparedStatement..._经典数据库之间接口程序

推荐文章

热门文章

相关标签