[Pytorch系列-51]:循环神经网络RNN - torch.nn.RNN类的参数详解与代码示例_文火冰糖的硅基工坊的博客-程序员秘密

技术标签: 循环神经网络  RNN  人工智能-PyTorch  深度学习  人工智能  pytorch  

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_程序员秘密

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121505015


目录

第1章 RNN神经网络的理论基础

第2章 torch.nn.RNN类

2.1 原型

2.2 案例

2.3 解读

第3章 forward前向运算的输入与输出

3.1 输入输出原型

3.2 输入解读

3.3 输出解读

第4章 代码示例:Pytorch预定义RNN网络



第1章 RNN神经网络的理论基础

https://blog.csdn.net/HiWangWenBing/article/details/121387285https://blog.csdn.net/HiWangWenBing/article/details/121387285

第2章 torch.nn.RNN类

2.1 原型

2.2 案例

2.3 解读

(1)input_size输入样本的向量长度

假如在NLP中,需要把一个单词输入到RNN中,而这个单词的向量化编码是300维的,那么这个input_size就是300。 input_size其实就是规定了输入样本的维度。用f(wX+b)来类比的话,这里输入的就是X的维度。

在上图案例中,input_size = 2。

(2)hidden_size:“时序”单元中隐藏层的输出特征的向量长度

在“时序”单元中,隐藏层可以输出多个属性,hidden_size就是定义了属性的size。

在上图案例中,hidden_size = 3。

 (3)num_layers: 隐藏层堆叠的层数

 在上图案例中,num_layers= 1。

(4)nonlinearity:定义激活函数

nonlinearity == 'tanh'或'relu':

(5)bias:是否需要偏置项

(6)batch_first:定义

定义了输入参数的形状,pytorch与Numpy的定义不同。

batch_first==True:               (batch, seq, feature)
batch_first==False(torch默认): (seq, batch, feature)

torch之所以把seq放在最前面,这是因为在时序逻辑处理中,通常会优先批量读取多个序列数据。

(7)dropout: 是否需要dropout隐层的输出

在RNN单元串联或堆叠的网络中,会出现梯度消失或梯度爆炸的情形。通过随机的dropout一些隐层的输出,可以降低梯度消失的情形。

(8)bidirectional:是否为双向RNN网络

 双向网络中,两个方向上,各自有独立的权重矩阵和状态张量。

第3章 forward前向运算的输入与输出

3.1 输入输出原型

3.2 输入解读

  • input:用于存放输入样本的张量,张量的形状如下:

math:`(L, N, H_{in})` when ``batch_first=False`` or :

math:`(N, L, H_{in})` when ``batch_first=True`

N:batch size, 一次可以送个一个batch的数据,batch size描述的可以同时并行输入的序列串的个数。
L:sequence length,连续多个输入样本,一次性送入RNN网络或foward函数中,RNN会依次输出sequence length批次的输出。sequence length可以串行输入序列的个数。
H_{in} :input_size
H_{out} :hidden_size
  • h_o: 用于存放RNN初始的隐藏状态,通常为上一时刻预测时隐层状态的输出,如果没有上一时刻,这设置全0.

3.3 输出解读

  • output:RNN网络的输出
  • h_n:    RNN网络隐层的输出

 

 hn就是RNN的最后一个隐含状态。

第4章 代码示例:Pytorch预定义RNN网络

(1)环境准备

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库
import time as time

import torch             # torch基础库
import torch.nn as nn    # torch神经网络库
import torch.nn.functional as F
from torch.autograd import Variable

import torchnlp
from torchnlp.word_to_vector import GloVe
# from torchnlp.word_to_vector import Glove

print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)
print(torch.backends.cudnn.version())
Hello World
1.10.0
True
10.2
7605

(2)定义一个序列化输入

# 定义输入序列的长度
seq_len = 3

#定义batch的长度
batch_size = 1

# input_size: 输入特征的数量/维
input_size  = 2

# 定义输入样本
input = torch.randn(seq_len, batch_size, input_size)
print(input.shape)
print(input)
torch.Size([3, 1, 2])
tensor([[[-0.1100,  0.7480]],

        [[-0.0672, -1.2424]],

        [[ 0.1957,  0.4846]]])

备注:

输入是一个样本输入特征的长度为2, 样本的序列化串的长度为3。因此输入shape=3*2。

比如文本:“I love you”,就 是一个序列,由三个单词组成,不管单词由多少个字母组成,每个单词被编码成一个长度=2的向量。

(3)定义RNN网络

# 定义RNN网络
# input_size: 输入特征的数量/维

# hidden_size(横向):隐藏层的size(不是层数),即输出序列的长度,也是输入序列的长度
hidden_size = 3
# num_layers (纵向): 隐藏层的堆叠层数
num_layers = 1

# 定义H0的状态
h0 = torch.zeros(num_layers, batch_size, hidden_size) 
print(h0.shape)
torch.Size([1, 1, 3])

备注:用于保存隐藏层状态的张量,其shape,取决于RNN网络的结构。

在这里,RNN网络为单层,即num_layers = 1,每一层提取的隐藏层特征为3,即hidden_size = 3

因此,隐藏层的特征输出为1 * 3, 加上batch就是1 * 1 * 3

注意的是:在这里,batch number是放在中间,网络的层数是放在首位的。

# 定义卷积神经网络
rnn = torch.nn.RNN(input_size= input_size, hidden_size = hidden_size, num_layers = num_layers, bias=True, bidirectional=False)
print(rnn) 

#显示神经网络的参数
# 备注:
# 输入权重矩阵weight_hh_l0是 hidden_size * input_size ? 不是说权重共享的吗?
# 隐层权重矩阵weight_hh_l0是 hidden_size * hidden_size ?

print(rnn.parameters)
for key, value in rnn.state_dict().items():
    print(key)
    print(value)
RNN(2, 3)
<bound method Module.parameters of RNN(2, 3)>
weight_ih_l0
tensor([[-0.2329, -0.3762],
        [-0.2569, -0.3311],
        [-0.4779,  0.2690]])
weight_hh_l0
tensor([[ 0.1738, -0.4826,  0.2800],
        [-0.5386, -0.0141, -0.2643],
        [ 0.5265,  0.1451, -0.1797]])
bias_ih_l0
tensor([ 0.1088, -0.0520,  0.0557])
bias_hh_l0
tensor([-0.1980, -0.3961, -0.2472])

备注:

ih:输入到隐藏层的权重矩阵=2 * 3, 这是因为这里定义的隐藏层的输入特征=2, 输出特征为3.

hh:隐藏状态到隐藏层的权重矩阵=3 * 3, 这是因为这里定义的隐藏层的状态特征=3, 输出特征为3,他们比如是相等的。

(4)网络输出:单输入,单输出

# 单个单词输入
# 0输出矩阵 = seq_len *  hidden_size,
# h输出矩阵 =    1    *  hidden_size 
input_single = input[0]
input_single = input_single.reshape(1, batch_size, input_size)

output, h = rnn(input_single, h0)

print("input:单输入")
print(input_single.shape)
print(input_single)

print("\noutput:单输出")
print(output.shape)
print(output)



# h是最后的输出
print("\nhiden:隐藏状态输出")
print(h.shape)
print(h)
input:单输入
torch.Size([1, 1, 2])
tensor([[[0.0749, 0.5568]]])

output:单输出
torch.Size([1, 1, 3])
tensor([[[-0.3060, -0.5728, -0.0773]]], grad_fn=<StackBackward0>)

hiden:隐藏状态输出
torch.Size([1, 1, 3])
tensor([[[-0.3060, -0.5728, -0.0773]]], grad_fn=<StackBackward0>)

(4)网络输出:序列输入,序列输出

# 序列输入(多个单词组成序列)
# 0输出矩阵 = seq_len *  hidden_size,
# h输出矩阵 =    1    *  hidden_size 
print("input:序列输入")
print(input.shape)
print(input)

output, h = rnn(input, h0)
print("\noutput:序列输出")
print(output.shape)
print(output)

# h是最后的输出
print("\nhiden:隐藏状态输出")
print(h.shape)
print(h)
input:序列输入
torch.Size([3, 1, 2])
tensor([[[ 0.0749,  0.5568]],

        [[ 0.2010, -0.1058]],

        [[-1.7882,  1.2671]]])

output:序列输出
torch.Size([3, 1, 3])
tensor([[[-0.3060, -0.5728, -0.0773]],

        [[ 0.1050, -0.2649, -0.4977]],

        [[-0.1418, -0.3181,  0.8042]]], grad_fn=<StackBackward0>)

hiden:隐藏状态输出
torch.Size([1, 1, 3])
tensor([[[-0.1418, -0.3181,  0.8042]]], grad_fn=<StackBackward0>)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_程序员秘密

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121505015

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/HiWangWenBing/article/details/121505015

智能推荐

YUV基础知识(转)_请叫我小黑的博客-程序员秘密

四,视频编码4.1,颜色空间YCbCr色彩空间和它的变形(有时被称为YUV)是最常用的有效的表示彩色图像的方法。Y是图像的亮度(luminance/luma)分量,使用以下公式计算,为R,G,B分量的加权平均值:Y = kr R + kgG + kbB其中k是权重因数。其中每个色差分量为R,G,B值和亮度Y的差值:Cb = B -Y(注释:蓝色的色差)Cr = R -YCg = G- Y其中,Cb+Cr+Cg是一个常数(其实是一个关于Y的表达式),所以,只需要其中两个数值结合Y值就能够计算

l7sa008b故障代码_韩国麦克比恩电机-L7S 中文说明书 ver1.9.pdf_weixin_39853131的博客-程序员秘密

电工电子,电器,安装,维修,电磁炉,线路图,电路图,元器件,监测,故障,行业报告,行业分析Ver 1.9前言前言非常感谢选用乐星迈克彼恩的L7系列产品。本使用手册详解产品的使用方法以及注意事项。不正确的操作会引起产品安全事故或导致产品的损坏,使用之前请务必阅读本使用手册,正确操作。• 此手册根据软件版本的不同,可能会更新部分内容。• 此手册的任何部分在未经乐星迈克彼恩书面认可之前,不得以任何形式进...

python基础教程百度云-python从入门到精通视频百度云盘下载_编程大乐趣的博客-程序员秘密

python入门教程-9-Python编程风格(1).zippython入门教程-8-Python编程语言基础技术框架(4)之函数介绍.zippython入门教程-7-Python编程语言基础技术框架(3)之print输出.zippython入门教程-6-Python编程语言基础技术框架(2).zippython入门教程-60-Python执行环境及doctes.zippython入门教程-5-P...

HDU3764(Cyclic Nacklace)_;w the endhdu_Splaying的博客-程序员秘密

Cyclic NacklaceProblem DescriptionCC always becomes very depressed at the end of this month, he has checked his credit card yesterday, without any surprise, there are only 99.9 yuan left. he is too distressed and thinking about how to tide over the last

dockers镜像的制作_李氏程序员的博客-程序员秘密

[[email protected] ~]# docker imagesREPOSITORY TAG IMAGE ID CREATED SIZEcentos latest 75835a67d134 3 month...

随便推点

以太网帧、IP 帧、UDP/TCP帧、http 报文结构解析_tcp帧结构_程序员自我修养的博客-程序员秘密

我们从 OSI/RM 参考模型入手,来看 OSI/RM 七层模型中的每一层数据帧结构。一 OSI/RM 结构OSI 是不同制造商的设备和应用软件在网络中进行通信的标准,此模型已经成为计算机间和网络间进行通信的主要结构模型, 目前使用的大多数网络通信协议的结构都是基于 OSI 模型的。OSI 包括体系结构、服务定义和协议规范三级抽象。OSI 体系结构定义了一个七层模型用于进行进程间的通信,并...

深度学习从入门到精通——人工智能、机器学习与深度学习绪论_杨乐坤深度学习_小陈phd的博客-程序员秘密

人工智能、机器学习与深度学习人工智能定义人工智能历史机器学习分类,按照监督方式深度学习主要应用数学基础张量基本知识矩阵的秩:矩阵的逆矩阵的广义逆矩阵矩阵分解矩阵特征分解常见的概率分布二项分布均匀分布高斯分布 ***指数分布多变量概率分布条件概率联合概率条件概率和联合概率的关系贝叶斯公式(重点掌握)常用统计量信息论熵(Entropy)联合熵条件熵互信息KL散度,相对熵交叉熵最小二乘估计人工智能定义个人浅定义:让计算机机器模拟能到学习跟人一样掌握某种技能。类似于弱人工智能。如果让计算机跟人类一样具有自主

PHP代码审计中你不知道的牛叉技术点_代码审计技术点_Mamba start的博客-程序员秘密

一、前言 &nbsp;&nbsp;&nbsp; php代码审计如字面意思,对php源代码进行审查,理解代码的逻辑,发现其中的安全漏洞。如审计代码中是否存在sql注入,则检查代码中sql语句到数据库的传输 和调用过程。 入门php代码审计实际并无什么门槛要求,只需要理解基础的php语法规则,以及理解各种类型漏洞的出现原因则可以开始尝试审计php源代码。通常的漏洞演示中sql语句会直接传入php自带的函...

STM32的I2C通信实例pcf8591(AD/DA)原创干货_别把我的消息带回家乡的博客-程序员秘密

本程序是stm32f103 接pcf8591AD/DA板(YL-PCF8591),I2C通信,I2C通信用的是模拟I2C,不是STM32 自带的硬件I2C外设。实践检验很稳定。SCL接口用PB6,SDA用PB7,与STM32自带的硬件I2C接口针脚是一样的,这是巧合,你也可以随便定义成别Pin针脚。关于GPIO设置,最后有解释。Pcf8591地址写是0x90,读是0x91, AIN0地址是0x40 -------AIN3地址是0x43,注意AOUT地址也是0x40写比较...

云南农业大学C语言程序设计,云南农业大学341农业知识综合三考研真题笔记期末题等..._不死鹰阿江的博客-程序员秘密

2017年云南农业大学341农业知识综合三全套考研资料第一部分考研历年真题(非常重要)1-1(这一项是发送电子版)本套资料没有真题。赠送《重点名校近三年考研真题汇编》供参考。说明:不同院校真题相似性极高,甚至部分考题完全相同。赠送《重点名校考研真题汇编》供参考。第二部分2017年专业课笔记、讲义资料(基础强化必备)2-1《理论力学》考研复习笔记。说明:高分研究生复习使用,条理清晰,重难...

win10cmd上传文件到linux,linux之间和linux与windows之间用命令行进行文件传输_Emmamkq的博客-程序员秘密

linux之间传送单个文件:scp 文件 [email protected]:上传要存的路径传送文件夹:scp -r 文件夹 [email protected]:上传要存的路径下载单个文件:scp [email protected]:文件 本地要存的路径下载文件夹:scp -r [email protected]:文件夹 本地要存的路径带密码操作:在上面语句前面增加 sshpass -p '密码’linu...

推荐文章

热门文章

相关标签