tensorflow和python基于bp神经网络预测_Tensorflow 用训练好的模型预测_杨仲慈的博客-程序员秘密

技术标签: tensorflow和python基于bp神经网络预测  

本节涉及点:

从命令行参数读取需要预测的数据

从文件中读取数据进行预测

从任意字符串中读取数据进行预测

一、从命令行参数读取需要预测的数据

训练神经网络是让神经网络具备可用性,真正使用神经网络时,需要对新的输入数据进行预测,

这些输入数据 不像训练数据那样是有目标值(标准答案),而是需要通过神经网络计算来获得预测的结果。

通过命令行参数输入数据:

importnumpy as npimportsys

predictData=None

argt= sys.argv[1:]#获取命令行参数后循环判断每一个参数,并寻找是否有以“-predict=” 为开始的字符串#使用成员函数 startswith 判断是否以另一个指定的字符串开头#如果有,去掉 "-predict=" 这个前缀,只取后面剩余的字符串#tmpStr = v[len("-predict=")] 作用是让 tmpStr 等于命令行参数v 去掉开头 "-predict=" 后的字符#len() 的作用是 获得任意字符串的长度#使用 numpy包中的 fromstring 函数,把 tmpStr 中字符串转换为一个数组

for v inargt:if v.startswith("-predict="):

tmpStr= v[len("-predict="):] #注意这里使用了切片print("tmpStr: %s" %tmpStr)

predictData= np.fromstring(tmpStr, dtype=np.float32, sep=",")print("predictData: %s" % predictData)

运行结果如下:

使用 Anaconda 执行该程序:

# numpy 字符串转变为数组函数 np.fromstring(tmpStr,dtype=np.float32,sep=",")

是指将字符串 tmpStr,以字符 "," 为分隔符,转换为数组内数据项的数据类型是 float32 的数组

调用训练好的神经网络进行预测:

importtensorflow as tfimportnumpy as npimportrandomimportosimportsys

ifRestartT=False

predictData=None

argt= sys.argv[1:]for v in argt:

if v == "-restart":

ifRestartT = True

if v.startswith("-predict="):

tmpStr = v[len("-predict="):]

predictData = np.fromstring(tmpStr, dtype=np.float32, sep=",")print("predictData: %s" %predictData)

trainResultPath= "./save/idcard2"random.seed()

x=tf.placeholder(tf.float32)

yTrain=tf.placeholder(tf.float32)

w1= tf.Variable(tf.random_normal([4, 8], mean=0.5, stddev=0.1), dtype=tf.float32)

b1= tf.Variable(0, dtype=tf.float32)

xr= tf.reshape(x, [1, 4])

n1= tf.nn.tanh(tf.matmul(xr, w1) +b1)

w2= tf.Variable(tf.random_normal([8, 2], mean=0.5, stddev=0.1), dtype=tf.float32)

b2= tf.Variable(0, dtype=tf.float32)

n2= tf.matmul(n1, w2) +b2

y= tf.nn.softmax(tf.reshape(n2, [2]))

loss= tf.reduce_mean(tf.square(y -yTrain))

optimizer= tf.train.RMSPropOptimizer(0.01)

train=optimizer.minimize(loss)

sess=tf.Session()ififRestartT:print("force restart...")

sess.run(tf.global_variables_initializer())elif os.path.exists(trainResultPath + ".index"):print("loading: %s" %trainResultPath)

tf.train.Saver().restore(sess, save_path=trainResultPath)else:print("train result path not exists: %s" %trainResultPath)

sess.run(tf.global_variables_initializer())if predictData is not None:

result = sess.run([x, y], feed_dict={x: predictData})

print(result[1])

print(y.eval(session=sess, feed_dict={x: predictData})) #第二种 输出神经网络计算结果的方法,解释见下

sys.exit(0) # 终止程序

# 如果 predictData 的数据 是 “None” ,则继续训练

# 否则说明已经从命令行参数中读取了需要预测的数据,那么就调用神经网络进行预测,输出结果 结束程序

lossSum= 0.0

for i in range(5):

xDataRandom= [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10), int(random.random() * 10)]if xDataRandom[2] % 2 ==0:

yTrainDataRandom= [0, 1]else:

yTrainDataRandom= [1, 0]

result= sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})

lossSum= lossSum + float(result[len(result) - 1])print("i: %d, loss: %10.10f, avgLoss: %10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))if os.path.exists("save.txt"):

os.remove("save.txt")print("saving...")

tf.train.Saver().save(sess, save_path=trainResultPath)

resultT= input('Would you like to save? (y/n)')if resultT == "y":print("saving...")

tf.train.Saver().save(sess, save_path=trainResultPath)

print(y.eval(session=sess, feed_dict={x: predictData}))

直接调用张量 y 的 eval 函数,并在命名参数 session 中传入 会话对象 sess,在命名参数 feed_dict 中传入需要预测的输入数据,就可以得到y 的计算结果

注意: 用神经网络计算,不需要传入目标值 yTrain ,也不需要在 sess.run 函数的结果数组中指定训练变量 trian

二、从文件中读取数据进行预测

假设在 程序执行目录下有此文件 :

importtensorflow as tfimportnumpy as npimportrandomimportosimportsys

ifRestartT=False

predictData=None

argt= sys.argv[1:]#同样,先获取命令行参数,从前忘后遍历,如果有 “-file=” ,会从该参数指定的文件中读取数据#读取数据后放进 predictData 中,但此时, predictData 会是一个二维数组,其中每一行代表文件中的一行数据#为了保持一致,我们把用命令行参数 "-predict=" 指定的预测输入数据也套上了一个方括号变成二维数组【虽然只有一行】#使用 predictData.shape[0] 获取二维数组的行数#因为数组的形态本身也是一个数组,其中下标为 0 的数字代表了它的行数

for v inargt:if v == "-restart":

ifRestartT=Trueif v.startswith("-file="):

tmpStr= v[len("-file="):]print(tmpStr)

predictData= np.loadtxt(tmpStr, dtype=np.float32, delimiter=",")

predictRowCount= predictData.shape[0]print("predictRowCount: %s" %predictRowCount)if v.startswith("-predict="):

tmpStr= v[len("-predict="):]

predictData= [np.fromstring(tmpStr, dtype=np.float32, sep=",")]print("predictData: %s" %predictData)

trainResultPath= "./save/idcard2"random.seed()

x=tf.placeholder(tf.float32)

yTrain=tf.placeholder(tf.float32)

w1= tf.Variable(tf.random_normal([4, 8], mean=0.5, stddev=0.1), dtype=tf.float32)

b1= tf.Variable(0, dtype=tf.float32)

xr= tf.reshape(x, [1, 4])

n1= tf.nn.tanh(tf.matmul(xr, w1) +b1)

w2= tf.Variable(tf.random_normal([8, 2], mean=0.5, stddev=0.1), dtype=tf.float32)

b2= tf.Variable(0, dtype=tf.float32)

n2= tf.matmul(n1, w2) +b2

y= tf.nn.softmax(tf.reshape(n2, [2]))

loss= tf.reduce_mean(tf.square(y -yTrain))

optimizer= tf.train.RMSPropOptimizer(0.01)

train=optimizer.minimize(loss)

sess=tf.Session()ififRestartT:print("force restart...")

sess.run(tf.global_variables_initializer())elif os.path.exists(trainResultPath + ".index"):print("loading: %s" %trainResultPath)

tf.train.Saver().restore(sess, save_path=trainResultPath)else:print("train result path not exists: %s" %trainResultPath)

sess.run(tf.global_variables_initializer())if predictData is not None:

for i in range(predictRowCount):

print(y.eval(session=sess, feed_dict={x: predictData[i]}))

sys.exit(0)#用一个循环,把 predictData 中的所有行的数据都输入神经网络中计算一边,最后输出结果

lossSum = 0.0

for i in range(500000):

xDataRandom= [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10), int(random.random() * 10)]if xDataRandom[2] % 2 ==0:

yTrainDataRandom= [0, 1]else:

yTrainDataRandom= [1, 0]

result= sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})

lossSum= lossSum + float(result[len(result) - 1])print("i: %d, loss: %10.10f, avgLoss: %10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))if os.path.exists("save.txt"):

os.remove("save.txt")print("saving...")

tf.train.Saver().save(sess, save_path=trainResultPath)

resultT= input('Would you like to save? (y/n)')if resultT == "y":print("saving...")

tf.train.Saver().save(sess, save_path=trainResultPath)

就可以 程序从 data2.txt 中获取了数据并转换成为一个二维数组,神经网络载入训练的过程数据后,根据当时的可变参数取值对每一行数据进行了预测

三、从任意字符串中读取数据进行预测

[[1,2,3,4],[2,4,6,8],[5,6,1,2],[7,9,0,3]]

上方是 python 中定义数组的写法,那么可以用 python 提到的 eval 函数把这个 字符串转换成为想要的数组类型。

假设有一个文本文件,data3.txt 且 有且仅有 上述字符串作为文件内容,编程实现,从文件中读取数据进行预测 :

importtensorflow as tfimportnumpy as npimportrandomimportosimportsys

ifRestartT=False

predictData=None

argt= sys.argv[1:]#如果制定了命令行参数 "-datafile=”,程序就从指定的文件中读取文件的全部内容#也就是把文件中的内容作为一个大字符串整个读进变量 fileStr 中#open 函数是 python 中用于打开指定位置文件的函数,会返回一个文件对象#调用该文件对象的 read 函数,就可以把文本文件的内容都读进来#再调用 eval 函数把这个字符串转换为 python 的数据对象#这里,python 会把它转换成一个 list 对象,直接用 numpy 的 array 函数就可以把它转换为数组

for v inargt:if v == "-restart":

ifRestartT=Trueif v.startswith("-file="):

tmpStr= v[len("-file="):]

predictData= np.loadtxt(tmpStr, dtype=np.float32, delimiter=",")

predictRowCount=predictData.shape[0]print("predictRowCount: %s" %predictRowCount)if v.startswith("-dataFile="):

tmpStr = v[len("-dataFile="):]

fileStr = open(tmpStr).read()

predictData = np.array(eval(fileStr))

predictRowCount = predictData.shape[0]print("predictRowCount: %s" % predictRowCount)if v.startswith("-predict="):

tmpStr= v[len("-predict="):]

predictData= [np.fromstring(tmpStr, dtype=np.float32, sep=",")]print("predictData: %s" %predictData)

trainResultPath= "./save/idcard2"random.seed()

x=tf.placeholder(tf.float32)

yTrain=tf.placeholder(tf.float32)

w1= tf.Variable(tf.random_normal([4, 8], mean=0.5, stddev=0.1), dtype=tf.float32)

b1= tf.Variable(0, dtype=tf.float32)

xr= tf.reshape(x, [1, 4])

n1= tf.nn.tanh(tf.matmul(xr, w1) +b1)

w2= tf.Variable(tf.random_normal([8, 2], mean=0.5, stddev=0.1), dtype=tf.float32)

b2= tf.Variable(0, dtype=tf.float32)

n2= tf.matmul(n1, w2) +b2

y= tf.nn.softmax(tf.reshape(n2, [2]))

loss= tf.reduce_mean(tf.square(y -yTrain))

optimizer= tf.train.RMSPropOptimizer(0.01)

train=optimizer.minimize(loss)

sess=tf.Session()ififRestartT:print("force restart...")

sess.run(tf.global_variables_initializer())elif os.path.exists(trainResultPath + ".index"):print("loading: %s" %trainResultPath)

tf.train.Saver().restore(sess, save_path=trainResultPath)else:print("train result path not exists: %s" %trainResultPath)

sess.run(tf.global_variables_initializer())if predictData is not None:

for i in range(predictRowCount):

print(y.eval(session=sess, feed_dict={x: predictData[i]}))

sys.exit(0)

lossSum= 0.0

for i in range(500000):

xDataRandom= [int(random.random() * 10), int(random.random() * 10), int(random.random() * 10), int(random.random() * 10)]if xDataRandom[2] % 2 ==0:

yTrainDataRandom= [0, 1]else:

yTrainDataRandom= [1, 0]

result= sess.run([train, x, yTrain, y, loss], feed_dict={x: xDataRandom, yTrain: yTrainDataRandom})

lossSum= lossSum + float(result[len(result) - 1])print("i: %d, loss: %10.10f, avgLoss: %10.10f" % (i, float(result[len(result) - 1]), lossSum / (i + 1)))if os.path.exists("save.txt"):

os.remove("save.txt")print("saving...")

tf.train.Saver().save(sess, save_path=trainResultPath)

resultT= input('Would you like to save? (y/n)')if resultT == "y":print("saving...")

tf.train.Saver().save(sess, save_path=trainResultPath)

执行程序:

当然,这里的格式也符合网络间传递数据的最常用的格式之一: JSON

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_42358261/article/details/111921685

智能推荐

markdown(md)文件的基本常用编辑语法_md文件语法_千寻~的博客-程序员秘密

.md即markdown文件的基本常用编写语法(图文并茂)原文:https://www.cnblogs.com/liugang-vip/p/6337580.html起因:因为现在的前端基本上都用上了前端构建工具,那就难免要写一些readme等等的说明性文件,但是这样的文件一般都是.md的文件,编写的语法自然跟其他格式的文件有所区别,置于为什么要用这种格式的文件,不要问我,我也不知道,大...

PHP 微信小程序 WebSocket MySQL Redis实现聊天功能_amateur12的博客-程序员秘密

1.Mysql 实现离线消息池。如果一个用户不在线,则其他用户发送给他的消息暂时存储在mysql。待该用户上线时,再从离线消息池取出发送。2.Redis 实现每个连接websocket的服务都唯一绑定一个用户。通过用户账号 = fd 存到redis中。微信小程序:websocket.php代码:<?php//创建WebSocket Server对象,监听0.0.0.0:9501端口$ws = new Swoole\WebSocket\Server('0.0.0...

自动光学检测系统(AOI)光学成像系统设计_钢铁男儿的博客-程序员秘密

光学成像系统设计在光学成像模块系统中光学系统的设计就是搭建一个可以快速准确采集具有高质量、高对比度、低噪声等的特征信息图像,使想要检测的特征部分尽量的突出、清晰以方便图像处理。光源及照明方式在AOI检测系统中光源是十分重要的组成部分,直接影响着输入数据质量的30%以及处理的速度,其中几个主要的作用分别是:增强目标物体与背景的对比度;提高待测物体边缘的清晰度;去除噪声消除阴影等。这些效果在图像处理的算法上体现的最为明显,好的光源可以简化软件处理步骤,优化系统,提高准确度在机器视觉的检测领域应...

蓝桥杯算法训练 大小写转换_水蛙菌的博客-程序员秘密

题目链接问题描述  编写一个程序,输入一个字符串(长度不超过20),然后把这个字符串内的每一个字符进行大小写变换,即将大写字母变成小写,小写字母变成大写,然后把这个新的字符串输出。  输入格式:输入一个字符串,而且这个字符串当中只包含英文字母,不包含其他类型的字符,也没有空格。  输出格式:输出经过转换后的字符串。输入输出样例样例输入AeDb样例输出aEdB代码:#include<bits/stdc++.h>using namespace std;string s;i

java复合型计算器_Java实现简单混合计算器_xr7m99的博客-程序员秘密

这个计算器并不是基于逆波兰实现的,而是通过简单的递归,一层一层地计算最终求得结果。具体的图形化界面可以参考我的另外一个篇博客:基于逆波兰表达式实现图形化混合计算器,这里我只是简单的介绍一下怎样求得算术表达式的结果,另外如果有需要可以加入另外那个博客中的Check方法,来检查表达式的合法性。计算表达式的算式如下所示(GitHub仓库地址):import java.math.BigDecimal;im...

zimbra修改记录分享_jiangyongyuan的博客-程序员秘密

晚上找到了jetty的日志,终于找到为何jetty没启动成功的原因。(原因是在做文字替换时,把类文件也给替换了,导致服务无法启动)zimbra系统已经重新跑起来,登陆页面暂时的解决方案:添加LogLogin.jsp页面。skyBook在页面中打开:http://172.17.1.14/zimbra/public/LogLogin.jsp ,便能使用邮件系统。今后通过传递用户名与密码,如:h...

随便推点

Yara、Snort和Sigma规则_yara规则_摔不死的笨鸟的博客-程序员秘密

Yara规则是基于二进制文件的静态HEX数据内容实现的扫描规则。简单点说,就是基于原始文件的内容数据扫描规则。Snort规则是基于IDS入侵检测系统,主要针对流量中数据包内容编写的扫描规则。SIGMA是一种通用的开放签名格式,允许以简单的方式描述SIEM系统中的相关日志事件。Yara规则很多人对yara规则是比较熟悉的。yara规则根据用途可以分为hunting yara规则和查杀yara规则两种。hunting作用的yara规则编写相对来说比较简单。除了命中目的样本外,还允许命中更多无关的黑样

Python 3.65 安装geopandas_ZHOU-LONG的博客-程序员秘密

geopandas在windows上安装极易出错,因为它依赖其他必要的库包,pip install xxx会有问题出现,个人建议: 方案一:直接下载对应Python版本的.whl文件,地址:https://www.lfd.uci.edu/~gohlke/pythonlibs/#pip。            方案二:安装conda,可以从清华大学镜像下载。地址是 https://mirro...

物联网常见通信协议RFID、NFC、Bluetooth、ZigBee等梳理_weixin_30902675的博客-程序员秘密

1 概述在上一篇文章《物联网常见通信协议与通讯协议梳理【上】-通讯协议》中,对物联网常用通信协议和通讯协议作了区分,并对通讯协议进行了分享;本文将对常用的通信协议进行剖析,重点面向市场上使用率较高的,且又不是诸如TCP/IP之类老生常谈的。2 近距离通信协议2.1 RFIDRFID的空中接口通信协议规范基本决定了RFID的工作类型,RFID读写器和相应类型R...

OI梗_oi圈是什么意思_百事可爱仔的博客-程序员秘密

OI梗本条目收录与OI(信息学竞赛,英语:Olympiad in Informatics)及其参赛选手圈子有关的流行文化。本条目仅收录常见梗,过于专业的用语不予收录,寻找专业用语请去OI Wiki等专业网站。目录1 OI用语1.1 %%%1.2 蒟蒻1.3 神犇/巨佬1.4 水题1.5 自动机1.6 卡常1.7 爆零2 OI典故/成句2.1 关于SPFA,它死了2.2 I AK IOI2.3 骗分导论2.4 o年OI一场空,xxxx见祖宗2.5 rp++2.6 我来NOI

银联Applepay_测试参数切换正式环境操作指南_银联csr证书_Bloodyer的博客-程序员秘密

一、  流程简介下载CSR文件(银联平台)将CSR提交至苹果(苹果开发者会员中心)查看证书使用流程(证书上传至银联平台)替换证书(证书pfx及验签证书cer)服务端更换请求交易接口地址客户端mode改为00一、  详细说明1.下载CSR文件(选做)    登陆商户服务平台(https://merchant.unionpay.com/),用户名及密码见

Git clone命令出现"fatal repository not found"错误_weixin_30535913的博客-程序员秘密

有时候使用Git命令"git clone [url]"将远程仓库中的代码爬取下来的时候系统会报错"fatal repository not found"。出现这个错误的一个可能的原因是本地已存储的git账号密码与爬取仓库所属的账号密码不同。解决办法:进入控制面板 >> 凭据管理器 >> Window凭据:选择普通凭据中保存的git账号信息进行编辑或者删除...

推荐文章

热门文章

相关标签