PyTorch训练模型添加L1/L2正则化的两种实现方式_l2正则化 pytorch_hlld26的博客-程序员秘密

技术标签: 深度学习  pytorch  L1正则化  L2正则化  

L1/L2正则化的作用

L1正则化作用到参数会产生更稀疏的解,既能使参数在训练过程中尽量靠近最优解的同时,一些参数为0。L1正则化的稀疏性质被广泛应用于特征选择,可从特征集合中选出具有代表的特征子集,以此简化机器学习问题。L1正则化的表达式如下:
f ( θ ) = ∣ ∣ θ ∣ ∣ 1 = ∑ i ∣ θ i ∣ f(\theta) = || \theta ||_{1} = \sum_{i} | \theta_{i} | f(θ)=θ1=iθi
由于上述表达式存在绝对值形式,不能直接求导,但可以使用次梯度表示。L1正则化的次梯度如下:
∇ f ( θ ) = s i g n ( θ ) \nabla f(\theta) = sign(\theta) f(θ)=sign(θ)
符号函数sign用于获取输入参数 θ \theta θ逐个元素的正负号。
L2正则化通常又被称为权重衰减,通过添加一个正则化项使得参数在训练时更加接近原点,可防止模型过拟合。更抽象地来说,L2正则化会对减少目标函数无关的部分参数进行衰减,而能显著减少目标函数的部分参数则不受影响。L2正则化通常仅对卷积层的权重进行惩罚,不包括偏置项,对应表达式如下:
f ( θ ) = ∣ ∣ θ ∣ ∣ 2 = 1 2 θ T θ f(\theta) = || \theta ||_{2} = \frac{1}{2} \theta^{T} \theta f(θ)=θ2=21θTθ
不难看出,上述表示式的梯度等于参数本身,既梯度如下:
∇ f ( θ ) = θ \nabla f(\theta) = \theta f(θ)=θ

PyTorch添加L1/L2正则化

在使用PyTorch训练模型时,可使用两种方式添加L1/L2正则化:一种是添加正则化项到损失函数中,另一种是在backward()之后,添加正则化项到参数变量的梯度中,然后再进行step()

方式一:添加到损失函数

def l1_regularization(model, l1_alpha):
    l1_loss = []
    for module in model.modules():
        if type(module) is nn.BatchNorm2d:
            l1_loss.append(torch.abs(module.weight).sum())
    return l1_alpha * sum(l1_loss)

def l2_regularization(model, l2_alpha):
    l2_loss = []
    for module in model.modules():
        if type(module) is nn.Conv2d:
            l2_loss.append((module.weight ** 2).sum() / 2.0)
    return l2_alpha * sum(l2_loss)

方式二:添加到参数梯度

def l1_regularization(model, l1_alpha):
    for module in model.modules():
        if type(module) is nn.BatchNorm2d:
            module.weight.grad.data.add_(l1_alpha * torch.sign(module.weight.data))

def l2_regularization(model, l2_alpha):
    for module in model.modules():
        if type(module) is nn.Conv2d:
            module.weight.grad.data.add_(l2_alpha * module.weight.data)
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/hlld__/article/details/114276683

智能推荐

程序的本质之三ELF文件中与符号(symbol)相关的section的定义_tanglinux的博客-程序员秘密

操作系统:CentOS Linux release 7.7.1908内核版本:3.10.0-1062.1.1.el7.x86_64运行平台:x86_64参考文献:http://refspecs.linuxfoundation.org/本文根据/usr/include/elf.h文件和程序编译的详细过程文中所述的tanglinux.c源码来分析可执行文件中与符号(symbol)...

aes解密设置utf8 php,PHP实现AES加密解密_weixin_39964660的博客-程序员秘密

1、mcrypt_encrypt AES加密,解密1 class Lib_desEnctyp2 {3 private $key = "";4 private $iv = "";56 /**7 * 构造,传递二个已经进行base64_encode的KEY与IV8 *9 * @param string $key10 * @param string...

JavaWeb之JavaScript_达少~的博客-程序员秘密

知识回顾:JavaWeb之Java基础知识增强JavaWeb之JDBCJavaWeb之数据库连接池JavaWeb之HTML&CSS文章目录今日内容JavaScript:今日内容1. JavaScript基础JavaScript:* 概念: 一门客户端脚本语言 * 运行在客户端浏览器中的。每一个浏览器都有JavaScript的解析引擎 * 脚本语言:不需要编译,直接就可以被浏览器解析执行了* 功能: * 可以来增强用户和html页面的交互过程,可以来控制html元素,让页

Qt--模拟按下按键(键盘)_qt 模拟按键_贝勒里恩的博客-程序员秘密

一、前言最近在做QWT开发的时候碰到一个问题,QwtPlotZoomer提供的放大、缩小操作只支持鼠标事件或键盘事件,但是我希望通过点击软件上的按钮去响应放大、缩小操作,但是事件槽函数不是我写的,不知道怎么调用,所以就只能给放大、缩小操作写一个快捷键了。例如:点击键盘I键放大、O键缩小,然后只需要在软件按钮槽函数中模拟按下了I键和O键,就可以响应相应的放大、缩小操作了。二、具体操作//模拟按下键盘I键QWidget *receiver = QApplication::focusWidget();

二维树状数组--poj1195_u010660276的博客-程序员秘密

Language:DefaultMobile phonesTime Limit: 5000MS Memory Limit: 65536KTotal Submissions: 13196 Accepted: 6136DescriptionSuppose that the fourth generation m

ASM(Active Shape Model)算法介绍_青莲太初的博客-程序员秘密

ASM是一种基于点分布模型(Point Distribution Model, PDM)的算法。在PDM中,外形相似的物体,例如人脸、人手、心脏、肺部等的几何形状可以通过若干关键特征点(landmarks)的坐标依次串联形成一个形状向量来表示。本文就以人脸为例来介绍该算法的基本原理和方法。首先给出一个标定好68个关键特征点的人脸面部图片,如下所示:

随便推点

[人工智能教程] 人工智能暑期课实践项目建议_SoftwareTeacher的博客-程序员秘密

哈工大人工智能暑期课实训项目建议这个博客介绍了暑期课实践作业的建议。 时间:7/10 - 7/22. 一周上课, 一周项目实践。 要求:项目实践的过程请用公开的博客记录。 项目的源代码请放到 github 中。 每4 ~ 5 人一个小组,从下面的候选中选择题目:1)手写数字识别增强版。 在MNist 的基础上进一步扩展, 阶段要求: 能实现多个数字的手写体识别 能实现加...

最优化算法——遗传算法解析_二代遗传算法_Miss_yuki的博客-程序员秘密

遗传算法是将生物进化论思想融入算法中来寻找最优值的一种编程方法,采用概率化的寻优方法。可分为四步:1)初始化:设置最大进化代数T(即迭代停止条件),随机生成M个个体作为初始群体。2)适应度:计算M个个体各自的适应度,评价不同个体对环境不同的适应情况,适应度越高被选择的概率越大。3)交叉,变异:增加个体数,产生新鲜染色体,进行筛选,看是否产生适应度更大的个体。4)迭代:直到达到终止条件,选出适应度最...

gpedit msc组策略面板 win10在哪里_Win10怎么添加加组策略功能 Win10组策略gpedit.msc安装方法..._weixin_39885803的博客-程序员秘密

Win10家庭版由于没有组策略gpedit.msc功能,设置很多东西都不方便,组策略应用太广泛了,其实Win10系统还是比较开放式的,我们可以通过一些方法来安装组策略功能,有需要的朋友可以学习一下。安装方法如下:1、先随便找个地方新建个文本文档,把下面这些内容复制进去;@echo offpushd "%~dp0"dir /b C:\Windows\servicing\Packages\Micros...

c语言函数fillpoly,C++_C语言fillpoly函数详解,C语言中,fillpoly函数的功能是 - phpStudy..._weixin_39820185的博客-程序员秘密

C语言fillpoly函数详解C语言中,fillpoly函数的功能是画一个多边形,今天我们就来学习学习。C语言fillpoly函数:填充一个多边形函数名:fillpoly功 能:画并填充一个多边形头文件:#include 原 型:fillpoly(int numpoints, int far *polypoints);参数说明:numpoints 为多边形的边数;far *polypoints...

vue总结_胡须表达式_kochiya_的博客-程序员秘密

Vuees特征变量的声明let a = 20;//使用let声明的是块级的const a = 5;//此时a不能再被赋值解构表达式let arr = [1,2,3]let [a,b,c] = arrlet obj = { name:"Tom", age:23}let {name,age,ss} = obj箭头函数let obj = { show1:function(a) {}, show2:(b)=>{}, show3(c){},

推荐文章

热门文章

相关标签