在Pytorch下,由于反向传播设置错误导致 loss不下降的原因及解决方案*_loss一直不降是反向传播的问题吗?-程序员宅基地

技术标签: Pytorch  梯度更新  深度学习  VGG  反向传播  

在Pytorch下,由于反向传播设置错误导致 loss不下降的原因及解决方案

本人研究生渣渣一枚,第一次写博客,请各路大神多多包含。刚刚接触深度学习一段时间,一直在研究计算机视觉方面,现在也在尝试实现自己的idea,从中也遇见了一些问题,这次就专门写一下,自己由于在反向传播(backward)过程中参数没有设置好,而导致的loss不下降的原因。

对于多个网络交替

【描述】简单描述一下我的网络结构,我的网络是有上下两路,先对第一路网络进行训练,使用groud truth对这一路的结果进行监督loss_steam1,得到训练好的feature.然后再将得到的feature级联到第二路,通过网络得到最后的结果,再用groud truth进行监督loss。【整个网络基于VGG19网络,在pytorch下搭建,有GPU环境】:

在这里插入图片描述

出现的情况,loss_steam1不怎么下降

这个问题确实折麽自己一段时间,结果发现自己出现了一个问题,下面将对这个问题进行分析和解答:

PyTorch梯度传递

在PyTorch中,传入网络计算的数据类型必须是Variable类型, Variable包装了一个Tensor,并且保存着梯度和创建这个Variablefunction的引用,换句话说,就是记录网络每层的梯度和网络图,可以实现梯度的反向传递.
则根据最后得到的loss可以逐步递归的求其每层的梯度,并实现权重更新。

在实现梯度反向传递时主要需要三步:

  1. 初始化梯度值:net.zero_grad() 清除网络状态
  2. 反向求解梯度:loss.backward() 反向传播求梯度
  3. 更新参数:optimizer.step() 更新参数

解决方案

自己在写代码的时候,还是没有对自己的代码搞明白。在反向求解梯度时,对第一路没有进行反向传播,这样肯定不能使这一路的更新,所以我就又加了一步:
loss_steam1.backward( retain_graph = True) //因为每次运行一次backward时,如果不加retain_graph = True,运行完后,计算图都会free掉。
loss.backward()

这样就够了么?我当时也是这么认为的结果发现loss_steam1还是没有降,又愁了好久,结果发现梯度有了,不更新参数,怎么可能有用!
optimizer_steam1.step() //这项必须加
optimizer.step()

哈哈!这样就完成了,效果也确实比以前好了很多。

参考博客

https://blog.csdn.net/u011276025/article/details/76997425

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_19329785/article/details/84260201

智能推荐

linux命令行模式登陆_linux unset lang-程序员宅基地

文章浏览阅读1.1k次。设置linux默认登陆模式开机以命令模式启动,执行:systemctl set-default multi-user.target开机以图形界面启动,执行:systemctl set-default graphical.targetlinux命令行模式登陆乱码修改/etc/default/locale命令:sudo vim /etc/default/locale1将下面这两行LANG=zh_CN.UTF-8LANGUAGE=zh_CN:zh替换为LANG="en_US.UTF-_linux unset lang

python实战--画小猪佩奇_let it go小猪佩奇-程序员宅基地

文章浏览阅读1.4k次。放个效果图:除了脑袋大效果勉强出来了,下边放代码# coding:utf_8import turtle as t# python引用turtle画小猪佩奇t.pensize(4)t.hideturtle()t.colormode(255)t.color((255,155,192),"pink")t.setup(840,500)# 画框大小t.speed(10)#画笔速度#..._let it go小猪佩奇

scp远程复制文件和目录_scp 远程复制目录-程序员宅基地

文章浏览阅读1.7w次。注意scp只能在linux操作系统平台上,要想在linux与window平台上传文件或者目录,下载一个winscp软件或者下载一个sshsecure shell软件安装在window上即可,非常方便,直接拖拉就行。1.上传本机文件到远程服务器 scp local_path/file_name user_name@remote_ip:remote_path/2.上传本机目录到远程服务器 scp ..._scp 远程复制目录

Excel读取wincc归档数据_excel 通过opcua获取wincc归档数据-程序员宅基地

文章浏览阅读3.7k次,点赞3次,收藏10次。1、先启动wincc 双击计算机,并勾选变量记录运行系统,在激活项目 点击变量管理,点击变量记录,归档名tank_archives _excel 通过opcua获取wincc归档数据

BUUCTF [De1CTF2019]Mine Sweeping17刷题笔记_ctf assembly-csharp.dll-程序员宅基地

文章浏览阅读785次。工具: 反编译工具 dnspy过程:打开压缩包,发现是一个扫雷游戏,选择反编译工具通过修改源码完成游戏,得到flag1.用dnspy打开Assembly-CSharp.dll文件路径:\Mine Sweeping\Mine Sweeping_Data\Managed\2.打开后找到使游戏结束的代码段 如下3.经分析,将this.bIsMine改成false后,即使点击到雷,游戏也不会结束,具体操作如下4.保存修改后的模块5.再次打开扫雷游戏,把所有的..._ctf assembly-csharp.dll

linux死机处理(我仅仅会使用一种方法)_linux虚拟机死机-程序员宅基地

文章浏览阅读2.2w次,点赞8次,收藏45次。系统环境:Ubuntu 16.04虚拟机:VM12Linux 死机有很多种情况,最常见的是系统负载过高导致的。可以是运行内存耗用极大的程序,也会迅速提升系统负载。由于系统负载过高导致的卡死,一定是解决的越快越好!不能再试图依赖任何图形界面的东西,因为鼠标都没有用,而且使用开启终端命令也没有用。首先 Ctrl + Alt +(F1-F6)中,进_linux虚拟机死机

随便推点

mybatis-plus 关于savebatch,saveorupdatebatch遇到的坑及解决办法-程序员宅基地

文章浏览阅读7.8w次,点赞26次,收藏139次。一.背景 最近mybatis-plus框架的更新,让我们基础开发中如虎添翼。其中基本的增删改查,代码生成器想必大家用着那叫一个爽。本人在使用中,也遇到一些坑。比如savebatch,saveorupdatebatch,看着这不是批量新增,批量新增或更新嘛,看着api进行开发,感觉也太好用啦。开发完一测试,速度跟蜗牛一样,针对大数据量真是无法忍受。在控制台上发现,怎么名义上是批量插入,还是一条一条的进行插入,难怪速度龟速。二.解决办法 查阅网上资料,大体有两种解决方案:..._saveorupdatebatch

LCD段码屏的功耗大吗?_段码屏功耗-程序员宅基地

文章浏览阅读2.2k次。LCD段码屏功耗不大,很小。功耗虽然等于电压和电流的乘积,但在使用中却有独立的意义,他标志着器件消耗电能的多少,这在微型,便携设备上意义重大。在主要的平板显示器件中,PDP,FED,VFD的功耗大,而EL,LED的功耗次之。目前有人称OLED的功耗比液晶显示还低,这是个误解。OLED的功耗和LED功耗在同一数量级,但是它是主动发光器件,不需背光源,而且只有在显示时才耗电,因此和增加了背光源的液晶显..._段码屏功耗

【AltiumDesigner18】关于modified polygon的一种解决方式_modified polygon错误怎么修改-程序员宅基地

文章浏览阅读1.9w次,点赞15次,收藏20次。问题描述:铺铜后进行DRC进行出现modified polygon冲突。如何解决:参考网上乱七八糟的解决方式未果后,我尝试看了下AD的用户手册,查找到相应部分描述如下:很显然这个冲突的出现,最可能的原因是因为铺铜中有铜块被搁置或未导入。对应规则:用户手册里也给出了参考的解决方案:显然只要将无用铜块进行合理修改或者删除即可,于是打开铺铜管理器->进行相应操作,冲突即可消除。例如:..._modified polygon错误怎么修改

pytorch二分类损失函数BCEWithLogitsLoss_criterion = nn.bcelogitsloss()-程序员宅基地

文章浏览阅读2.8k次。一个正确的语义分割二分类损失函数的计算应该是如下这样的criterion = nn.BCEWithLogitsLoss(weight=None)masks_pred = net(imgs)loss = criterion(masks_pred, true_masks) #使用二分类交叉熵epoch_loss += loss.item()先前在网络最后输出时使用sigmoid,然后使用B..._criterion = nn.bcelogitsloss()

《深入理解MySQL主从原理32讲》推荐篇-程序员宅基地

文章浏览阅读1.1w次,点赞12次,收藏90次。导读:作者:高鹏(网名八怪),《深入理解MySQL主从原理32讲》系列文的作者。2008年开始至今一直从事Oracle/MySQL相关工作,现任易极付高级DBA,Orac..._深入理解mysql主从原理32讲

推荐文章

热门文章

相关标签