Pytorch中,动态调整学习率、不同层设置不同学习率和固定某些层训练的方法_pytorch 不同层不同学习率optimizer.param_groups-程序员宅基地

技术标签: Pytorch  

动态调整学习率

三种方法我都写成直接调用的函数了,所以大家复制走拿去用就行了。

第一种 官方例子中按照milestone调整的办法

def adjust_learning_rate(optimizer, epoch, milestones=None):
    """Sets the learning rate: milestone is a list/tuple"""

    def to(epoch):
        if epoch <= args.warmup:
            return 1
        elif args.warmup < epoch <= milestones[0]:
            return 0
        for i in range(1, len(milestones)):
            if milestones[i - 1] < epoch <= milestones[i]:
                return i
        return len(milestones)

    n = to(epoch)

    global lr
    lr = args.base_lr * (0.2 ** n)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

args.warmup和args.base_lr分别是1和初始学习率。函数to的目的就是求出当前epoch在milestone的哪个范围内,不同范围代表不同的衰减率,用返回的数字来区别epoch的范围。

之后声明lr是全局的,这样做可能是因为在函数外部有使用lr的地方,函数内容就直接改变的是全局的lr。

第二种

来看一下旷世开源的shuffleNet系列中使用的学习率变化策略。

用的学习率衰减策略是根据当前迭代次数选取的。

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                    lambda step : (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)

每当运行一次 scheduler.step(),参数的学习率就会按照lambda公式衰减。

 

第三种

CCNet官方源码中改变学习率的方法。这个学习率衰减策略是最常用的,被称作多项式衰减法。

def lr_poly(base_lr, iter, max_iter, power):
    return base_lr*((1-float(iter)/max_iter)**(power))
            
def adjust_learning_rate(optimizer, learning_rate, i_iter, max_iter, power):
    """Sets the learning rate to the initial LR divided by 5 at 60th, 120th and 160th epochs"""
    lr = lr_poly(learning_rate, i_iter, max_iter, power)
    optimizer.param_groups[0]['lr'] = lr
    return lr

NOTE:

看到有些代码用for循环修改optimizer.param_groups,这个group的数目是model.parameters被分为几个字典,就是几个group,每个group有不同的学习率,weight_decay,等等

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=0.005)
    for o in optimizer.param_groups:
        # print(type(o)) # dict
        for k,v in o.items():
            print(k)


####
params
lr   # optimizer.param_groups[0]['lr']就是这个
momentum
dampening
weight_decay
nesterov

 

如果定义optimizer只用了一组parameters,不是用形如:

optimizer = SGD([{params: params_1, 'lr':0.1},

                             {params:params_2, 'lr': 0.2}]

 

那么不需要for循环了,直接访问optimizer.param_groups[0]['lr']修改。

否则还是for循环吧,可读性高。


 

不同层设置不同的学习率

先搭建一个小网络。

import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(3,32,3)
        self.conv2 = nn.Conv2d(32,24,3)
        self.prelu = nn.PReLU()
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                nn.init.constant_(m.bias.data,0)
            if isinstance(m,nn.Linear):
                m.weight.data.normal_(0.01,0,1)
                m.bias.data.zero_()

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.prelu(out)
        return out

我们现在看看这个模型的modules

model = Net()
for m in model.modules():
    print(m)


'''
Net(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1))
  (prelu): PReLU(num_parameters=1)
)
Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
Conv2d(32, 24, kernel_size=(3, 3), stride=(1, 1))
PReLU(num_parameters=1)

'''

比如说我们想对前两个卷积层的学习率设置为0.2(注意是前两层),对PRelu激活函数中的参数设置为0.02。剩下的参数用0.3的学习率(假象网络比我写的更加大)。

model = Net()
conv_params = list(map(id,model.conv1.parameters()))   #提出前两个卷积层存放参数的地址
conv_params += list(map(id,model.conv2.parameters()))
prelu_params = []
for m in model.modules():    #找到Prelu的参数
    if isinstance(m, nn.PReLU):
        prelu_params += m.parameters()

#假象网络比我写的很大,还有一部分参数,这部分参数使用另一个学习率
rest_params = filter(lambda x:id(x) not in conv_params+list(map(id,prelu_params)),model.parameters())  #提出剩下的参数
print(list(rest_params))
'''
>> []   #是空的,因为我举的例子没其他参数了
'''
import torch.optim as optim

optimizer = optim.Adam([{'params':model.conv1.parameters(),'lr':0.2},
                        {'params':model.conv2.parameters(),'lr':0.2},
                        {'params':prelu_params,'lr':0.02},
                        {'params':rest_params,'lr':0.3}
                        ])

 

固定某些层训练

思路就是利用tensor的requires_grad,每一个tensor都有自己的requires_grad成员,值只能为True和False。

  • 我们对不需要参与训练的参数的requires_grad设置为False。
  • 在optim参数模型参数中过滤掉requires_grad为False的参数。

还是以上面搭建的简单网络为例,我们固定第一个卷积层的参数,训练其他层的所有参数。

  • 需要遍历第一层的参数,然后为其设置requires_grad
model = Net()
for name, p in model.named_parameters():
    if name.startswith('conv1'):
        p.requires_grad = False

import torch.optim as optim
optimizer = optim.Adam(filter(lambda x: x.requires_grad is not False ,model.parameters()),lr= 0.2)

为了验证一下我们的设置是否正确,我们分别看看model中的参数的requires_grad和optim中的params_group()。

for p in model.parameters():
    print(p.requires_grad)
'''
False
False
True
True
True
'''
for p in optimizer.param_groups[0]['params']:
    print(p.requires_grad)
    print(type(p))

'''
True
<class 'torch.nn.parameter.Parameter'>
True
<class 'torch.nn.parameter.Parameter'>
True
<class 'torch.nn.parameter.Parameter'>
'''

能看出优化器仅仅对requires_grad为True的参数进行迭代优化。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_34914551/article/details/87699317

智能推荐

寻找...-程序员宅基地

文章浏览阅读280次。菜鸟程序员寻找人教我会很虚心学习的希望有人愿意教我

opencv cv2.imread()报错: error: (-215:Assertion failed) !_src.empty() in function ‘cv::cvtColor‘_cv2.error: opencv(4.7.0) /io/opencv/modules/imgpro-程序员宅基地

文章浏览阅读2.2k次,点赞2次,收藏3次。v2.error: OpenCV(4.7.0) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cv::cvtColor'_cv2.error: opencv(4.7.0) /io/opencv/modules/imgproc/src/color.cpp:182: error

推荐Windows电脑上最好用的3个mobi阅读器-程序员宅基地

文章浏览阅读9.3k次。mobi格式之所以流行,主要是源于亚马逊官网的电子书格式以及Kindle,但由于存在较多缺陷逐渐被azw3格式替代,mobi格式的文件无法直接用电脑打开,需要阅读器才能打开阅读。今天小编就为大家推荐3个在Windows电脑上最好用的mobi阅读器。第一款: neat reader neat reader是一款跨平台阅读器,可以支持azw3/azw/mobi/epub/pdf/txt等常见文档格式,也同时支持在Windows/Mac/Android/iOS系统上使用。 neat reader阅读..

生活经验_生活经验博客-程序员宅基地

文章浏览阅读212次。坚持原则1、背后不议论他人是非,拒绝人身攻击。2、坚持宁缺毋滥原则,重要事情坚决不将就或迁就。3、抵制诱惑,坚决不为新奇而体验令人上瘾的不好事物。4、坚持独立思考,明辨是非,科学评判,拒绝听风就是雨,跟风,为舆论绑架。5、若想人不知,除非己莫为。方法论1、做事讲究成本:包括经济成本,时间成本,体能成本。譬如在一下破解应用上,往往找破解方法会花费大量的时间成本,破解应用,往往也可能..._生活经验博客

RK3399平台入门到精通系列讲解 - 总目录_rk3399平台开发系列讲解(内核入门篇)1.20、container_of 获取结构体首地址原理-程序员宅基地

文章浏览阅读8.7w次,点赞261次,收藏1.1k次。欢迎大家来到内核笔记的《RK3399平台开发入门到精通系列讲解》,开始前博主先列出RK3399平台学习的大纲,同时这也可以作为大家学习RK3399内核与安卓框架的参考。下面蓝字都是传送门,点击进入即可:..._rk3399平台开发系列讲解(内核入门篇)1.20、container_of 获取结构体首地址原理

【genius_platform软件平台开发】第五十二讲:Linux系统之V4L2视频驱动详解_v4l2摄像头驱动课程-程序员宅基地

文章浏览阅读861次。1. 简介Video4Linux2(简称V4L2)是Linux下关于视频采集相关设备的驱动框架,为驱动和应用程序提供了一套统一的接口规范。支持三类设备,分别会在/dev目录下产生不同的设备节点:1.1 视频输入输出设备(video capture device,video output device)分别是提供视频捕获功能的摄像头类型设备和提供视频输出功能的设备,对应的设备名为videoX。这是我们最常用的一种设备类型。1.2 VBI设备(Vertical Blanking Interval_v4l2摄像头驱动课程

随便推点

字节跳动高频100道核心前端面试题解析-程序员宅基地

文章浏览阅读3.4k次。字节跳动的前端一直是大热之选,薪资和技术都是国内的最前沿。本文将为大家简单介绍一下字节技术岗的职级体系和相应的技术要求。并给大家分享一套高质量面试题:「由字节资深前端大佬整理的100道高频..._字节跳动前端面试题

HTML-Emmet(神器)_html 神器-程序员宅基地

文章浏览阅读709次。使用Emmet必须先安装插件,我用atom 安装插件翻墙什么的我就不写了 在这里我推荐两个学习emmet语法的网页 日常链接嘻嘻 前人总结的emmet语法使用,有动态图演示,新手推荐 官方语法_html 神器

Zigbee入门概念及背景知识_zigbee背景-程序员宅基地

文章浏览阅读700次。在学习Zigbee之前,需要了解与单片机学习的方法不同之处以及难点所在。学习Zigbee需要掌握协议栈和网络等相关知识,任务量较大。Zigbee的资料相对较少,初学者学习起来比较费劲,学习效果不理想。学习过程中需要利用软件和硬件工具,提高学习效率。Zigbee是一种无线通信方式,用于构建无线局域网,可以用于家居、工业、矿产、农业、医疗等领域。相比蓝牙和WIFI,Zigbee可以组建大规模网络,功耗低,但通信速率较小。Zigbee和以太网组网技术有所不同,用途、拓扑结构和通信特性等方面存在差异。Zigbee的_zigbee背景

HTTP中的GET和POST方法详解_http post get-程序员宅基地

文章浏览阅读1.4w次,点赞18次,收藏100次。一般来说GET是获取数据,POST是提交数据的。但是因为GET和POST都是HTTP的方法,HTTP又是是基于TCP/IP的关于数据在万维网中如何让通讯的协议。从本质上讲,GET和POST都是HTTP请求,都是TCP链接,是无区别的。但是HTTP协议既然有了这两个方法,就是为了在特定的情况下区分应用。1、GET是获取数据,POST是提交数据的。GET方法通常用于请求服务器发送某个资源,而且应该是安全的和幂等的。仅仅是获取资源信息,就像数据库查询一样,不会修改和增加数据,不会影响资源的状态。POST_http post get

初识 MongoDB_本关任务:根据编程要求启动 mongodb 服务。-程序员宅基地

文章浏览阅读2.6k次,点赞3次,收藏28次。第1关:启动 MongoDB本关任务:根据编程要求启动 MongoDB 服务。第2关:启动 MongoDB 多实例本关任务:根据第一关单实例(服务)的启动教程,按照编程要求,启动两组实例(服务)。第3关:退出客户端和关闭 MongoDB 服务本关任务:关闭端口为27017的 MongoDB 服务。标题..._本关任务:根据编程要求启动 mongodb 服务。

ClusterStorage-236-4-客户端配置挂载与授权控制(ACL&Quota)_acl_enable 和 quota-程序员宅基地

文章浏览阅读225次。0.实验环境图1.客户端配置挂载在workstation上,安装glusterfs文件客户端,创建挂载目录,编辑挂载配置文件,进行挂载,查看文件系统。[root@workstation ~]# yum install -y glusterfs-fuse[root@workstation ~]# mkdir /test[root@workstation ~]# mkdir ..._acl_enable 和 quota

推荐文章

热门文章

相关标签