pytorch学习笔记本_pytorch深度学习的笔记本-程序员宅基地

技术标签: pytorch 深度学习 机器学习  

torch 下:nn、autograd 、mm、 optim
nn下: functional、Parameter、BCEWithLogitsLoss、Sequential 、Module

神经网络搭建的简单过程

#导入常用的库
import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt

#生成数据
np.random.seed(1)
m = 400
N = int(m/2)
D = 2
x = np.zeros((m,D))
y = np.zeros((m,1),dtype='uint8')
a = 4

for j in range(2):
    ix = range(N*j, N*(j+1))
    t = np.linspace(j*3.12,(j+1)*3.12,N) + np.random.randn(N)*0.2
    r = a*np.sin(4*t) + np.random.randn(N)*0.2
    x[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
    y[ix] = j

x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float()1:通过定义函数去搭建网络结构
def func_net(x):
    x1 = torch.mm(x, w1) + b1
    x1 = F.tanh(x1)
    x2 = torch.mm(x1,w2) + b2
    
    return x2

optim_1 = torch.optim.SGD([w1,w2,b1,b2], 0.01)
criterion = nn.BCEWithLogitsLoss()

for e in range(10000):
    out = func_net(Variable(x))
    loss = criterion(out, Variable(y))
    optim_1.zero_grad()
    loss.backward()
    optim_1.step()
    if (e+1)%1000 == 0:
        print('epoch:{},loss:{}'.format(e+1, loss))


法二:采用pytorch的Sequential模块搭建
seq_net = nn.Sequential(
    nn.Linear(2,4),
    nn.Tanh(),
    nn.Linear(4,1)
    

)

optim_2 = torch.optim.SGD(seq_net.parameters(), 0.01)
criterion = nn.BCEWithLogitsLoss()

for e in range(10000):
    out = seq_net(Variable(x))
    loss = criterion(out, Variable(y))
    optim_2.zero_grad()
    loss.backward()
    optim_2.step()
    if (e+1)%1000 == 0:
        print('epoch:{},loss:{}'.format(e+1,loss))

法三:采用pytorch的Module模块搭建

class module_net(nn.Module):
    def __init__(self,num1,num2,num3):
        super(module_net,self).__init__()
        self.layer1 = nn.Linear(num1,num2)
        self.layer2 = nn.Tanh()
        self.layer3 = nn.Linear(num2,num3)
    def forward(self,x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
    
mod_net = module_net(2,4,1)

optim_3 = torch.optim.SGD(mod_net.parameters(), 0.01)
criterion = nn.BCEWithLogitsLoss()

for e in range(10000):
    out = mod_net(Variable(x))
    loss = criterion(out,Variable(y))
    optim_3.zero_grad()
    loss.backward()
    optim_3.step()
    if (e+1)%1000 == 0:
        print('epoch:{},loss:{}'.format(e+1,loss))

三种方法,法二和法三的效果会更好,这是因为 PyTorch 自带的模块比我们写的更加稳定,这2一些初始化的问题相关。

模型的访问

mod_net.layer1
》
Linear(in_features=2, out_features=4, bias=True)


mod_net.layer1.weigtht
》
Parameter containing:
tensor([[-0.0118, -1.9755],
        [ 1.3713,  1.3738],
        [ 1.4056, -1.5453],
        [-0.0824, -0.0446]], requires_grad=True)

for i in mod_net.children():
	print(i)
》
Linear(in_features=2, out_features=4, bias=True)
Tanh()
Linear(in_features=4, out_features=1, bias=True)


for i in mod_net.modules():
	print(i)
》
module_net(
  (layer1): Linear(in_features=2, out_features=4, bias=True)
  (layer2): Tanh()
  (layer3): Linear(in_features=4, out_features=1, bias=True)
)
Linear(in_features=2, out_features=4, bias=True)
Tanh()
Linear(in_features=4, out_features=1, bias=True)

模型的保存和读取

1)将参数和模型保存在一起
torch.save(seq_net, 'save_seq_net.pth')
net_ = torch.load('save_seq_net.pth')
#读取
net_.layer1.weight
》
Parameter containing:
tensor([[-0.0118, -1.9755],
        [ 1.3713,  1.3738],
        [ 1.4056, -1.5453],
        [-0.0824, -0.0446]], requires_grad=True)
2)只保存参数
torch.save(seq_net.state_dict,'save_seq_net.pth')
seq_net2 = nn.Sequential(
    nn.Linear(2, 4),
    nn.Tanh(),
    nn.Linear(4, 1)
)

seq_net2.load_state_dict(torch.load('save_seq_net_params.pth'))

一般推荐使用第二种,可移植性更强

参数初始化
Xavier初始化法
来源文献:http://proceedings.mlr.press/v9/glorot10a.html
公式:
ω = : U n i f o r m ( − 6 n j + n j + 1 , 6 n j + n j + 1 ) \omega =: Uniform(-\frac{\sqrt{6}}{\sqrt{n_{j}+n_{j+1}}},\frac{\sqrt{6}}{\sqrt{n_{j}+n_{j+1}}}) ω=Uniform(nj+nj+1 6 ,nj+nj+1 6 )
n j n_{j} nj n j + 1 n_{j+1} nj+1分别是该层的输入和输出数目。

for layer in mod_net.modules():
    if isinstance(layer, nn.Linear):
        nn.init.xavier_uniform_(layer.weight(), gain=nn.init.calculate_gain('relu'))
    
#一般初始化方法:
for layer in net2.modules():
    if isinstance(layer, nn.Linear):
        param_shape = layer.weight.shape
        layer.weight.data = torch.from_numpy(np.random.normal(0, 0.5, size=param_shape)) 
#gain用于设置初始化参数的标准差来匹配特定的激活函数

batch_size
batch_size越大梯度具有越高的随机性,batch_size越小梯度越稳定

基于梯度的优化算法
1)torch.optim.SGD()
2)torch.optim.SGD(momentum=0.9):相当于每次在进行参数更新的时候,都会将之前的速度考虑进来,每个参数在各方向上的移动幅度不仅取决于当前的梯度,还取决于过去各个梯度在各个方向上是否一致,如果一个梯度一直沿着当前方向进行更新,那么每次更新的幅度就越来越大,如果一个梯度在一个方向上不断变化,那么其更新幅度就会被衰减,这样我们就可以使用一个较大的学习率,使得收敛更快,同时梯度比较大的方向就会因为动量的关系每次更新的幅度减少
3)torch.optim.Adagrad():改进动量法和随机梯度下降对任何参数都使用相同的学习率的情况,但随着梯度平方的不断累加,学习率会越来越小,导致收敛乏力
4)torch.optim.RMSprop():前面我们提到了 Adagrad 算法有一个问题,就是学习率分母上的变量 s 不断被累加增大,最后会导致学习率除以一个比较大的数之后变得非常小,这不利于我们找到最后的最优解,所以 RMSProp 的提出就是为了解决这个问题。
5)torch.optim.Adadelta():Adadelta 算是 Adagrad 法的延伸,它跟 RMSProp 一样,都是为了解决 Adagrad 中学习率不断减小的问题,RMSProp 是通过移动加权平均的方式,而 Adadelta 也是一种方法,有趣的是,它并不需要学习率这个参数。
6)torch.optim.Adam():Adam 是一个结合了动量法和 RMSProp 的优化算法,其结合了两者的优点。
几种梯度下降算法的详细介绍

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_41802245/article/details/107220924

智能推荐

Rggplot2_下面关于ggplot2中,叙述错误的是a.ggplot2必须在rstidio平台上运行b.ggplo-程序员宅基地

文章浏览阅读1.2k次。文章转载自:https://www.cnblogs.com/nxld/p/6059603.html分析数据要做的第一件事情,就是观察它。对于每个变量,哪些值是最常见的?值域是大是小?是否有异常观测?ggplot2图形之基本语法:ggplot2的核心理念是将绘图与数据分离,数据相关的绘图与数据无关的绘图分离ggplot2是按图层作图ggplot2保有命令式作图的调整函数,使其更具灵活..._下面关于ggplot2中,叙述错误的是a.ggplot2必须在rstidio平台上运行b.ggplot2

ORB-SLAM3 ROS 运行_orbslam3 ros-程序员宅基地

文章浏览阅读9.7k次,点赞7次,收藏116次。为单眼,单眼+惯性,立体视觉,立体+惯性和RGB-D构建节点环境为:ROS Melodic 和 Ubuntu 18.04编译1、将源码中的 Examples/ROS/ORB_SLAM3 路径添加到ROS_PACKAGE_PATH环境变量中打开 .bashrc file:gedit ~/.bashrc把下面这行加到最下面一行,把“PATH”换成你放ORB_SLAM3的路径export ROS_PACKAGE_PATH=${ROS_PACKAGE_PATH}:“PATH”/ORB_SLAM3/_orbslam3 ros

sql优化-程序员宅基地

文章浏览阅读43次。1、带or的sqlunionunion all2、where中带max()的sqlnot exists

云南计算机专升本数据结构_云南专升本计算机专业考试科目有哪些?-程序员宅基地

文章浏览阅读720次。同学们应该要知道云南专升本计算机专业考试科目是哪些,毕竟关系到未来考试,根据公布的云南专升本政策来看云南专升本计算机专业考试科目是高等数学和公共英语和数据结构,下面跟随易学仕专升本网来看看吧!一、云南专升本可以报考哪些学校?云南可以报考的院校有很多,下面就和易学仕一起来看看云南专升本学校名单有哪些吧!云南师范大学,云南艺术学院,云南民族大学,西南林业大学,云南农业大学,昆明理工大学,楚雄师范学院,..._云南专升本数据结构是考什么

论文导读:实时语义分割网络BiSeNetV1和v2_语义分割bisenetv2-程序员宅基地

文章浏览阅读4.4k次,点赞2次,收藏27次。文章目录一、背景二、BiSeNetV1三、BiSeNetV2v1论文地址:BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentationv2论文地址:BiSeNet V2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation一、背景低水平的细节特征(spatial information 空间信息)和高水平_语义分割bisenetv2

Python----一个对象的属性可以是另外一个类型创建的对象_python 一个对象的属性是另一个对象list-程序员宅基地

文章浏览阅读3.9k次,点赞2次,收藏6次。del 生命周期结束, 可以删除一个对象class Cat: """这是一个猫类""" def __init__(self, name): print("初始化开始...") self.name = name def eat(self): print("%s eat fish" % self.name) def drink(s_python 一个对象的属性是另一个对象list

随便推点

python整数溢出问题_Python多个整数溢出漏洞-程序员宅基地

文章浏览阅读419次。BUGTRAQ ID: 30491CVE(CAN) ID: CVE-2008-2315,CVE-2008-2316,CVE-2008-3142,CVE-2008-3143,CVE-2008-3144Python是一种开放源代码的脚本编程语言。Python中存在多个整数溢出漏洞,可能允许恶意用户导致拒绝服务或入侵有漏洞的系统。1) stringobject、unicodeobject、buffer..._该漏洞可以通过人整数转换触发 python core 的 过载,以触发拒绝服务

python读取txt数据写入excel_python读txt写入excel-程序员宅基地

文章浏览阅读8.8k次,点赞3次,收藏22次。在公司接到一个任务,从txt中抓取数据写入excel,txt格式固定,并且有多个txt文件先安装excel的读写支持,参考:https://www.cnblogs.com/cllovewxq/p/5363636.html就是下载xlrd和xlwt,进入该目录分别运行python setup.py install,这个程序只用到写入操作--xlwt思路: 逐个打开txt文件,抓_python读txt写入excel

centos7镜像在虚拟机上安装centos7详细教程_虚拟机安装centos7安装教程详细-程序员宅基地

文章浏览阅读7.8k次,点赞4次,收藏15次。有许多人在安装虚拟机这方面不会操作,在安装过程中经常出现问题,所以今天出一期下载安装虚拟机的教程一、前期安装准备二、下载centos7镜像文件三、安装centos7四、打开VMware第一步:在VMware新建一个虚拟机第二步:..._虚拟机安装centos7安装教程详细

实现一个java版本的redis(1)——实现一个内存型KV存储_java实现kv数据库-程序员宅基地

文章浏览阅读657次。前排说一下,这是一个十分简陋的KV内存数据库,作为笔者实现redis的第一章,大佬可以走了,因为真的很简陋。仅供学习。心血来潮,看到了开源项目godis,但自己对go又没有很熟悉,一开始去看了godis,一头雾水,索性想到为什么不用java来实现一个redis呢?说干就干​ 第一步,我们来实现一个简单的运行在单机的内存型的KV数据库,严格来说这不是redis,和redis差了十万八千里。就是将一个字典,通过网络的方式提供了出去。但毕竟第一步,我们就来实现一个简单一点的(十分的简陋)。​ 我们主要来实_java实现kv数据库

URL处理几个关键的函数parse_url、parse_str_httpurl.parse(url);-程序员宅基地

文章浏览阅读745次。parse_url()该函数可以解析 URL,返回其组成部分。它的用法如下:array parse_url(string $url)此函数返回一个关联数组,包含现有 URL 的各种组成部分。如果缺少了其中的某一个,则不会为这个组成部分创建数组项。组成部分为:scheme - 如 http host - 如 localhostport - 如 80user pass _httpurl.parse(url);

推荐好用的Linux远程连接工具_linux最好的远程工具-程序员宅基地

文章浏览阅读3.1k次,点赞2次,收藏11次。在连接linux的时候用了很多工具:**Xshell** **SecureCRT** **Putty** **FinalShell** **MobaXterm**,还有很多其他的其中Xshell SecureCRT都不是免费的,当然有办法破解,这个在网上一大堆,就不再说了,putty非常轻量级也很好用,也是最常用的,但是我在此推荐两款更好用的,国产的FinalShell和MobaXterm;一、FinalShell网址:http://www.hostbuf.com/这个软件很强大可以实时查看c_linux最好的远程工具