tensorflow(六):RNN实现手写体识别MNIST_rnn手写体识别_科大小笨的博客-程序员秘密

技术标签: python深度学习  tensorflow  

一、RNN结构

   这是一个标准的RNN结构图,图中每个箭头代表做一次变换,也就是说箭头连接带有权值。左侧是折叠起来的样子,右侧是展开的样子,左侧中h旁边的箭头代表此结构中的“循环“体现在隐层。
   在展开结构中我们可以观察到,在标准的RNN结构中,隐层的神经元之间也是带有权值的。也就是说,随着序列的不断推进,前面的隐层将会影响后面的隐层。图中O代表输出,y代表样本给出的确定值,L代表损失函数,我们可以看到,“损失“也是随着序列的推荐而不断积累的。
   除上述特点之外,标准RNN的还有以下特点:
   1、权值共享,图中的W全是相同的,U和V也一样。
   2、每一个输入值都只与它本身的那条路线建立权连接,不会和别的神经元连接。

 

  以上是RNN的标准结构,然而在实际中这一种结构并不能解决所有问题,例如我们输入为一串文字,输出为分类类别,那么输出就不需要一个序列,只需要单个输出。如图。 

 

 

  同样的,我们有时候还需要单输入但是输出为序列的情况。那么就可以使用如下结构: 

这里写图片描述

   还有一种结构是输入虽是序列,但不随着序列变化,就可以使用如下结构:

这里写图片描述

 

 二、LSTM结构

   Long Short Term 网络—— 一般就叫做 LSTM ——是一种 RNN 特殊的类型,可以学习长期依赖信息。LSTM 由Hochreiter & Schmidhuber (1997)提出,并在近期被Alex Graves进行了改良和推广。在很多问题,LSTM 都取得相当巨大的成功,并得到了广泛的使用。LSTM 通过刻意的设计来避免长期依赖问题。记住长期的信息在实践中是 LSTM 的默认行为,而非需要付出很大代价才能获得的能力。所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。

推荐博客:https://blog.csdn.net/zhaojc1995/article/details/80572098

三、tensorflow代码实现

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
#输入的图片是28*28
n_inputs=28 #输入一行,一行有28个数据
max_time=28 #一共28行,执行的次数。图像是28*28的,一行训练一次,共训练28次。
lstm_size=100 #隐层单元
n_classes=10 #10个分类
batch_size=50 #每批次50个样本
n_batch=mnist.train.num_examples // batch_size #计算一共有多少批次
#这里的none表示第一维度可以是任意的长度
x=tf.placeholder(tf.float32,[None,784])
#正确的标签
y=tf.placeholder(tf.float32,[None,10])
#初始化权值
weights=tf.Variable(tf.truncated_normal([lstm_size, n_classes], stddev=0.1))
#初始化偏执值
biases=tf.Variable(tf.constant(0.1,shape=[n_classes]))
#定义RNN网络
def RNN(X,weight,biases):
    #inputs=[batch_size,max_time,n_inputs]
    inputs=tf.reshape(X,[-1,max_time,n_inputs])
    #定义LSTM基本CELL
    lstm_cell=tf.contrib.rnn.BasicLSTMCell(lstm_size)
    outputs,final_state=tf.nn.dynamic_rnn(lstm_cell, inputs,dtype=tf.float32)
    #final_state[1]是隐藏层状态,0是细胞状态,有图可知,LSTM细胞输出有两个,一个是细胞状态,一个是隐层状态
    results=tf.nn.softmax(tf.matmul(final_state[1], weights) + biases)
    return results
#计算rnn的返回结果
prediction=RNN(x, weights, biases)
#损失函数
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
#使用AdamOptimizer进行优化
trian_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#结果存放在一个布尔型列表中
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) #argmax返回一维张量中最大的值所在的位置
#求准确率
accuarcy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #把correct_prediction变为float32类型
#初始化
init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(50):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(trian_step, feed_dict={x:batch_xs,y:batch_ys})
        acc=sess.run(accuarcy, feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print ("Iter "+str(epoch)+", Testing Accuarcy= " + str(acc))

四、运行结果

Iter 0, Testing Accuarcy= 0.6849
Iter 1, Testing Accuarcy= 0.7931
Iter 2, Testing Accuarcy= 0.8225
Iter 3, Testing Accuarcy= 0.8967

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/TuZiFaDai/article/details/107143245

智能推荐

c# 【MVC】WebApi开发实例_c# mvc模式开发例子_smartsmile2012的博客-程序员秘密

using System;using System.Collections.Generic;using System.ComponentModel.DataAnnotations;using System.Linq;using System.Web;namespace ProductStore.Models{ //商品实体类 public class Product

adb介绍、下载安装、使用[email protected]随风的博客-程序员秘密

adb下载安装及使用adb介绍:Android Debug Bridge(安卓调试桥) tools。它就是一个命令行窗口,用于通过电脑端与模拟器或者是设备之间的交互。ADB是一个C/S架构的应用程序,由三部分组成:运行在pc端的adb client:命令行程序”adb”用于从shell或脚本中运行adb命令。首先,“adb”程序尝试定位主机上的ADB服务器,如果找不到ADB服务器,“adb”程序自动启动一个ADB服务器。接下来,当设备的adbd和pc端的adb server建立连接后,ad.

Redis的持久化_as_you_like_zx的博客-程序员秘密

Redis是一个基于内存的非关系型数据库,也就是说,基于内存意味着:一旦服务器断电或者出现其他故障,存储在内存中的数据就会丢失。为了安全,Redis提供了持久化这个机制进行保障,简单来说,持久化就是将存储在内存中的数据转存到硬盘上。这样服务器重新启动时,就可以将数据从硬盘上恢复。Redis提供了两种持久化方式:RDB和AOF。RDB方式快照模式,定期把内存中当前时刻的数据保存到磁盘,这是Re...

poi生成excel文件(通用方法)_HFUTJungle的博客-程序员秘密

public class ExcelCreateUtil { private static final Logger log = LoggerFactory.getLog(ExcelCreateUtil.class); /** * 生成 excel 文件 * * @param resulsts excel 内容 * @param head...

React Native Debugger - ERROR - TurboModuleRegistry.getEnforcing(...): ‘NativeReanimated‘ 暂时解决方案_殇尘的博客-程序员秘密

问题重述最近在练习React Native的Drawer navigation时,一运行项目就报了三条错误,忘了截图错误,借用下图(图片转自github)同时查看项目的package.json,发现react-native-reanimated的依赖版本是2.0.0。{... "dependencies": {... "react-native-reanimated": "^2.0.0",... },...}解决方案(暂时)经过一番探索,怀疑是最新版本的Drawer na

DBSCAN算法理解_dbscan(n_jobs)_just_gogogo0412的博客-程序员秘密

DBSCAN算法理解1.DBSCAN简介DBSCAN(Density-Based Special Clustering of Application with Noise),它是基于密度聚类算法,密度可以理解为样本点的紧密程度,而紧密度的衡量则需要使用半径和最小样本量进行评估,如果在指定的半径内,实际样本量超过给定的最小样本量阈值,则认为是密度高的对象。DBSCAN密度聚类算法可以非常方便的发现样本集中的异常点,故通常可以使用该算法实现异常点的检测。它可以发现任何形状的样本簇,并且具有很强的抗噪声能力。

随便推点

TCP内功心法_shaukon的博客-程序员秘密

TCP内功心法网络通信靠协议,经典协议TCP。可靠传输无人比,润物无声似空气。确定源头与目的,各十六位不可弃。先传序列后确认,确认号码要加一。四位首部偏移量,6位保留无含义。六位协议状态位,最为重要无可替。紧急指针URG,需要用时方为1。确认标记标记ACK,一为有效零放弃。PSH标志推应用,置为一时表完毕...

iOS 引用外部静态库(.a文件)时或打包.a时,Category方法无法调用。崩溃,静态库所用到的第三方不打到静态库,防止导入工程文件冲突_Z苗的博客-程序员秘密

我的这个是MJRefresh,学习打.a包Terminating app due to uncaught exception ‘NSInvalidArgumentException’, reason: ‘-[UITableView setMj_footer:]: unrecognized selector sent to instance 0x7fa37a871000’结果这个问题如下设置:...

谈谈多线程编程(一)- 安全性策略_weixin_30369087的博客-程序员秘密

在多线程编程中,安全是我们考虑的最重要的因素。通常程序员都会使用锁来满足安全要求,但是只用锁并不能写出良好的多线程代码,因此我们有必要更深入一点,对线程安全策略进行更加全面的了解。首先谈谈影响线程安全的因素:影响线程安全的因素有三个因素影响到了多线程下的安全性:原子性、可见性和指令顺序 一个原子操作是单独的、不可分割的。但是高级语言中的大多数语句,包括一些简单的读写语句,...

python读取excel某一行-Python 读取csv的某行_weixin_37988176的博客-程序员秘密

站长用Python写了一个可以提取csv任一列的代码,欢迎使用。Github链接csv是Comma-Separated Values的缩写,是用文本文件形式储存的表格数据,比如如下的表格:就可以存储为csv文件,文件内容是:No.,Name,Age,Score1,Apple,12,982,Ben,13,973,Celia,14,964,Dave,15,95假设上述csv文件保存为"A.c...

Java异常总结_hit100410628的博客-程序员秘密

1.异常分类 (1)运行时异常(unchecked exception):继承自java.lang. RuntimeException类常见5种:ClassCastException(类型转换异常)IndexOutOfBoundsException(数组越界)NullPointerException(空指针)ArrayStoreException(数据存储异常,操作...

嵌入式设备开发专题《MT7688开发,wm8960音频驱动移植到LEDE17.01系统》_wm8960 dts_物联网研究室BBC的博客-程序员秘密

先说明,在openwrt和lede项目未合并之前,也就是在openwrt15.05版本,内核3.18.29,是集成wm8960驱动补丁的,当时合并之后lede17.01(内核版本4.4.124)去除了wm8960驱动补丁,所以得折腾把它加回去。widora发布的openwrt版本,内核3.18.29,是对wm8960的驱动做了小小的优化,所以选择采用widora所包含的驱动补丁进行移植。1.先把w...

推荐文章

热门文章

相关标签