第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)_imdb数据集rnn分类案例tensorflow-程序员宅基地

技术标签: 少奶奶的深度学习指北  RNN循环神经网络  IMDB下载  RNN梯度  Tensorflow2.0  

一、时序数据

卷积神经网络是针对二维位置相关的图片,它采取权值共享的思想,使用一个滑动的窗口去提取每一个位置相关的信息,自然界除了位置相关的信息以外,还存在另一种跟时序相关的数据类型,例如:序列信号,语音信号。对于按时间轴不停产生的信号,神经网络中,我们称其为temporal signals,而循环神经网络处理的就是拥有时序相关或者前后相关的数据类型(Sequence)。

二、embedding数据类型转换

计算机是无法处理时序数据的,所以,我们必须把这些时序数据转化成数值类型,以文本为例,文字和数值之间并没有一一对应的关系,而且,文字之间还有语义关系,没有大小关系,要是我们能找到一种表示方法能很好的表达文字之间的语义关系,那么用于训练或者分类就会很有帮助,而我们把这一转化过程统称为embedding。

在循环神经网络中,我们通常用下面的格式进行表示训练数据:

                                                                  [ b, seq_len, fearture_len ]

b:一个批次的大小。seq_len:时序数据的长度(一句话有seq_len个单词,一段录音有seq_len时刻等)。fearture_len:对于每一个seq_len,我们用多少纬度来表示(一个单词,我们可以用长度为4的向量表示)

例如: 

我喜欢你  =====》[ 1, 4, 4 ]: 只有一句话,所以b=1,这句话有4个汉字,所以seq_len=4,每个汉字,我们用一个4纬的向量表示,所以,fearture_len=4。同理,对于一段长达10秒的语音信号,我们可以这样表示为 [ 1, 10, 10 ]。

备注:在本篇博文中,我们只讨论文本处理即word embedding,对于语言序列不给予过多讨论。

三、word embedding

文本的转换不仅仅是逻辑上的转换,或者简单的一对一关系,其中还要包括语义上的转换。所以,word embedding需要满足一下两个条件:

1)语义相关。

单词在语义空间的相似度需要在数值空间表现出来(可利用位置距离或者概率表示),例如:输入king这个单词的话,那么在语义空间中和king相关的所有词都必须在数值空间中,与king的位置靠近(距离的远近代表着相关的程度)。

2)可训练。

下面,我们将利用layers简单的实现一次word embedding过程:

数据:I love it   ====》[ 1, 3, 4 ]

第一步:为三个单词设置一个编号。

第二步:调用layers的embedding接口,创建一个embedding层,并设置这个层的纬度。

第三步:把x输入embedding层中,embedding层就会为x随机初始化一个tensor对象

当前的embedding结果是不符合要求的,我们会通过后续训练去优化embedding。

四、循环神经网络

通过word embedding,我们可以把句子中的所有单词数值化后,映射到向量空间中去,为了得到最佳的word embedding结果,我们会把数值化的单词输入全连接网络中进行训练,具体如下:

                                         

每一个单词都输入到单独的全连接层中,得到的结果用于对情感的划分,但是,当句子过长时,其网络的层数和参数都会成直线上升,而且,当我们变换句子顺序时,其输出结果不变,即没有保留上下之间的语境信息。

针对参数上述问题,我们提出权重共享的思想,即每一个全连接层共享同一个w和b:

                                  

针对语义信息,由于现有网络没有全局的综合的高层的语义信息抽取过程,所以,我们需要一个全局的数据去存储整个句子的语义信息,即需要一个容器去存储从第一个单词到最后一个单词的总体语义信息,该容器被称为memory。

 

                                                        如上图所示,h0为初始memory,h1~h5是每一层输出的memory,后一层利用下面的公式会用到前一层的语境信息,

                                                                            h_2=x@w_x + h_1@w_h_h

所以最后一层就拥有了全局的语境信息,最后,我们可以利用这个全局的语境信息去做分类。

 通过观察上图中的全连接层,我们发现每一个全连接层的结构都是一样的(都是x@w_x + h_t@w_h_h形式),所以,我们把这些全连接层折叠起来就成了RNN循环神经网络了。具体公式如下:

                                               

备注:在RNN中,激活函数通常用thah(),ht为t时刻的memory,yt表示RNN在循环完毕后,利用一个全连接层汇总RNN中提取的高维特征,用于最终结果的分类。

五、RNN的可训练性

所有神经网络结构,都需要能进行反向梯度更新,那么,RNN是怎么求解梯度的呢?首先,我们把RNN网络按时间轴进行展开:

                                                         

RNN公式:           h_t =tanh(W_I_x +W_Rh_t_-_1)                        y_t=W_0h_t

loss函数:            losee = MSE(y,x)

我们对W_R进行求导:

                                                     \frac{\partial E_t}{\partial W_R} = \sum_{i=0}^{t}{\frac{\partial E_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}\frac{\partial h_t}{\partial h_i}\frac{\partial h_i}{\partial W_R}}

假设t=1,则:

                                        \frac{\partial E_1}{\partial W_R} = \frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial h_1}\frac{\partial h_1}{\partial h_0}\frac{\partial h_0}{\partial W_R}

假设t=2,则:

                                      \frac{\partial E_2}{\partial W_R} =\frac{\partial E_1}{\partial y_1}\frac{\partial y_1}{\partial h_1}\frac{\partial h_1}{\partial h_0}\frac{\partial h_0}{\partial W_R} + \frac{\partial E_2}{\partial y_2}\frac{\partial y_2}{\partial h_2}\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial h_0}\frac{\partial h_0}{\partial W_R}

我们可以发现,由于\frac{\partial E_t}{\partial y_t},\frac{\partial y_t}{\partial h_t},\frac{\partial h_i}{\partial W_R}是已知的,所以可以直接求解导数,但是\frac{\partial h_t}{\partial h_i}是未知的,我们把\frac{\partial h_t}{\partial h_i}进行展开可得:

                                                           \frac{\partial h_t}{\partial h_i} = \frac{\partial h_t}{\partial h_t_{-1}}\frac{\partial h_t_{-1}}{\partial h_t_{-2}}.....\frac{\partial h_i_{+1}}{\partial h_i}=\prod ^{t-1}_{k-i}\frac{\partial h_k_{+1}}{\partial h_k}

利用RNN公式可得:

                                                            \frac{\partial h_k_{+1}}{\partial h_k}=diag(f(W_I_{xi}+W_R_{h_i-1})) W_R

则求1到k时刻的梯度为:

                                                             \frac{\partial h_k}{\partial h_1} =\prod ^{k}_{i}diag(f(W_I_{xi}+W_R_{h_i-1})) W_R

所以,\frac{\partial h_t}{\partial h_i}也是可求的,这里提示一下,k时刻会得到{W_R}的k次方,这一项会导致RNN很容易出现梯度弥散或者梯度爆炸。

六、RNN训练时数据纬度的变化

本小节,博主将详细叙述在一次word embedding的循环网络训练过程中,其内部数据纬度的变化。

第一步:输入一个批次的训练数据集  [ b, seq_len, fearture_len ] = [ 128, 80, 100 ],由于seq_len是句子的长度,而一次RNN网络处理的是一个单词,所以,我们按seq_len进行展开,则数据纬度变成:[ b, fearture_len ] = [ 128, 100 ]。

第二步:根据公式h_2=x@w_x + h_1@w_h_h可得:

                        h = [b,feature_len]@[fearture_len, hidden_len] + [b, hidden_len]@[bidden_len, hidden_len]

则[fearture_len, hidden_len] = [ 100, 64],[b, hidden_len]=[128, 64],[bidden_len, hidden_len]=[64,64],这样我们就可以把[b, 100]纬度的数据降纬到[ b, 64 ]。

实战部分

七、IMDB数据集简介

 IMDB数据集包含了50000条偏向明显的电影评论,其中25000条作为训练集,25000作为测试集,其标签只有pos/neg两种,属于二分类问题,这里给出网上收集到的百度云下载链接:

链接:https://pan.baidu.com/s/1jcAZiGy0zeo9VjUKBDLZHA    

提取码:3aka 

八、代码讲解

本实战分部分主要是使用两层的SimpleRNNCell实现对IMDB数据集的训练。在下一章里,我们将使用层的方式和LSTM以及GRU方式实现对IMDB数据集的训练。之所以选择使用SimpleRNNCell来实现,主要是为了向大家讲解RNN的原理,希望能帮助到大家。

第一行代码:利用Tensorflow提供的接口加载数据集,在加载的同时,设置网络识别单词的最大个数(num_words),因为英语单词至少有5万多个,网络不可能把所有单词都记住,而且我们日常使用的也很少,所以,当设置total_words=10000时,代表的意思是,网络只对10000个不同单词进行编码。对于超过10000的单词用同一个特殊符号表示。

后两行代码:设置句子的长度,因为,句子的长度会直接影响到RNN的循环次数,所以,为了训练方便,我们会把句子进行截断成相同长度,例如:maxlen=80,代表着训练数据中的句子都会变成80个单词,多的舍去,少的补零。

构建dataset对象,并设置每一个Batch的大小,drop_remainder=True的意思是当数据不足以构成一个Batch时,舍去。

 

下面是模型构建中的代码

由于我们使用的是两层RNN,每一层都需要一个初始memory。units表示RNN影藏层的纬度,本次实战设置为64,

构建两层简单的Cell,并启用dropout功能。

这是RNN循环完毕后的全连接层,输出一个结果。下面给出自定义层的完整代码

把输入的训练数据进行一次embedding操作,并设置两个memory的初始状态。

把seq_len纬度进行展开,把结果传入SimpleRNNCell中,得到每一层的输出值和memory。

把循环后的结果放入全连接层中,并利用sigmoid函数进行缩放。下面给出自定义模型的完整代码

下面给出整体代码:

import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import time

os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'

total_words = 10000
max_review_len = 80
batchsz = 128
embedding_len = 100

(x_train, y_train) , (x_test, y_test) = keras.datasets.imdb.load_data(num_words = total_words)
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen = max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen = max_review_len)

db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.shuffle(1000).batch(batchsz, drop_remainder=True)

print('x_train shape:', x_train.shape, tf.reduce_max(y_train), tf.reduce_min(y_train))
print('x test shape:', x_test.shape)

class MyRnn(keras.Model):

    def __init__(self, units):
        super(MyRnn,self).__init__()

        self.state0 = [tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units])]

        self.embedding = layers.Embedding(total_words, embedding_len,
                                          input_length = max_review_len)

        self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.2)
        self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.2)

        self.outlayer = layers.Dense(1)

    def call(self, inputs, training=None):

        x = self.embedding(inputs)

        state0 = self.state0
        state1 = self.state1
        for word in tf.unstack(x, axis=1):
            out0, state0 = self.rnn_cell0(word, state0, training)
            out1, state1 = self.rnn_cell1(out0, state1)

        x = self.outlayer(out1)

        prob = tf.sigmoid(x)

        return prob

def main():
    units = 64
    epoch = 4
    start_time = time.time()
    model = MyRnn(units)
    model.compile(optimizer= keras.optimizers.Adam(0.001),
                  loss=tf.losses.BinaryCrossentropy(),
                  metrics = ['accuracy'])
    model.fit(db_train, epochs= epoch, validation_data = db_test)
    model.evaluate(db_test)
    end_time = time.time()
    print('all time: ' ,end_time - start_time)
if __name__ == '__main__':
    main()

在下一章中,少奶奶将继续讲解RNN网络的不足和改进,依旧是理论+实战,欢迎大家观看

开篇:开启Tensorflow 2.0时代

第一章:Tensorflow 2.0 实现简单的线性回归模型(理论+实践)

第二章:Tensorflow 2.0 手写全连接MNIST数据集(理论+实战)

第三章:Tensorflow 2.0 利用高级接口实现对cifar10 数据集的全连接(理论+实战实现)

第四章:Tensorflow 2.0 实现自定义层和自定义模型的编写并实现cifar10 的全连接网络(理论+实战)

第五章:Tensorflow 2.0 利用十三层卷积神经网络实现cifar 100训练(理论+实战)

第六章:优化神经网络的技巧(理论)

第七章:Tensorflow2.0 RNN循环神经网络实现IMDB数据集训练(理论+实践)

第八章:Tensorflow2.0 传统RNN缺陷和LSTM网络原理(理论+实战)

 

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/LQ_qing/article/details/100521923

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签