Tensorflow 之 MNIST CNN实现并保存、加载模型_cnn model.save-程序员宅基地

技术标签: Tensorflow  机器学习  人工智能  

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os

#download the data
mnist = keras.datasets.mnist

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

train_images = train_images / 255.0
test_images = test_images / 255.0

def create_model():
  # It's necessary to give the input_shape,or it will fail when you load the model
  # The error will be like : You are trying to load the 4 layer models to the 0 layer 
  model = keras.Sequential([
      keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
      keras.layers.MaxPool2D(),
      keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
      keras.layers.MaxPool2D(),
      keras.layers.Flatten(),
      keras.layers.Dense(576, activation=tf.nn.relu),
      keras.layers.Dense(10, activation=tf.nn.softmax)
  ])

  model.compile(optimizer=tf.train.AdamOptimizer(), 
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
  
  return model

#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])


#train
model = create_model()                                                 
model.fit(train_images, train_labels, epochs=4)

#save the model
model.save('my_model.h5')

#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model

new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])
new_model.summary()

#Evaluate

# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)

#Predicte

mypath = 'C:\\pythonp\\testdir2'

def getimg(mypath):
    listdir = os.listdir(mypath)
    imgs = []
    for p in listdir:
        img = plt.imread(mypath+'\\'+p)
        # I save the picture that I draw myself under Windows, but the saved picture's
        # encode style is just opposite with the experiment data, so I transfer it with
        # this line. 
        img = np.abs(img/255-1)
        imgs.append(img[:,:,0])
    return np.array(imgs),len(imgs)

imgs = getimg(mypath)

test_images = np.reshape(imgs[0],[-1,28,28,1])

predictions = new_model.predict(test_images)

plt.figure()

for i in range(imgs[1]):
  c = np.argmax(predictions[i])
  plt.subplot(3,3,i+1)
  plt.xticks([])
  plt.yticks([])
  plt.imshow(test_images[i,:,:,0])
  plt.title(class_names[c])
plt.show()

测试结果

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/uflswe/article/details/85212325

智能推荐

HTML数据输入模板,文本和HTML 模板-程序员宅基地

文章浏览阅读533次。在Go语言中,我们使用`template`包来进行模板处理,使用类似`Parse`、`ParseFile`、`Execute`等方法从文件或者字符串加载模板,然后执行类似上面图片展示的模板的merge操作。请看下面的例子:~~~func handler(w http.ResponseWriter, r *http.Request) {t := template.New("some template..._输入框中固定内容模板

于Ubuntu20.04安装Vivado19.2出现安装过程卡在generating installed device list一步的解决方法_ubuntu vivado generating installed device list-程序员宅基地

文章浏览阅读7.2k次,点赞7次,收藏23次。于Ubuntu20.04安装Vivado19.2出现安装过程卡在generating installed device list一步的解决方法一、ncurses库未安装二、配置LD_LIBRARY_PATH环境变量三、我进不去的未知网页查阅了Xilinx官方论坛,找到了该篇文章,里面有三种解决方法,搬运至此一、ncurses库未安装我遇到的就是这个问题,ncurses库未安装。我起初是不知道遇到的问题的,浏览中在多个贴子中提到了这个问题,于是就尝试了一下,解决了我的问题。方法如下:打开终端:Ctr_ubuntu vivado generating installed device list

Python用turtle库作图的时候,如何将画一下画完,不用显示画的过程?_turtle瞬间画完-程序员宅基地

文章浏览阅读1.3w次,点赞14次,收藏19次。Python用turtle库作图的时候,有时候等待太慢了,如何将画一下画完,不用显示画的过程?答案是可以的,使用:turtle.tracer(False)。 语句就可以实现,自己试试吧!..._turtle瞬间画完

史上最全的SpringBoot学习教程!会不断更新_springbot学习-程序员宅基地

文章浏览阅读9.9w次,点赞220次,收藏1.9k次。把写过的SpringBoot系列的文章全部整理在此,方便大家学习查看!_springbot学习

java 单元测试覆盖率调研_java junit test覆盖率 n/a-程序员宅基地

文章浏览阅读969次。单元测试覆盖率调研_java junit test覆盖率 n/a

vue 获取select的值及动态添加option_el-select 动态添加option-程序员宅基地

文章浏览阅读8.2k次。前端代码<div id="app"> <select class="form-control con1" @change="changeProduct($event)"> <option value="" disabled selected>请选择</option> <option v-for="i..._el-select 动态添加option

随便推点

C语言(循环)-程序员宅基地

文章浏览阅读1k次,点赞27次,收藏24次。如果表达式2为假(==0)则循环结束,如果表达式2为真(==1) 就执行下面的语句,执行完后,再执行表达式3(调整循环变量)之后判断表达式2的真假决定是否再进行下一次循环。就是说如果if后面的表达式为真那么就会执行后面的语句,while后面的表达式为真的话也会执行语句,当执行完语句后会返回再一次判断表达式的真假。它们的区别是while语句是可以进行循环的操作的。2. 假设要判断i是否为素数,需要拿2~i-1之间的数字去试除i,需要产⽣2~i-1之间的数字,也可以使⽤ 循环解决。那么还是打印数字1~10。

差分进化算法在物流运输中的优化解决方案-程序员宅基地

文章浏览阅读762次,点赞23次,收藏17次。1.背景介绍物流运输业是现代社会的重要组成部分,它涉及到各种各样的商品和物资在不同地点之间的运输。随着经济的发展和人口的增长,物流运输业面临着越来越多的挑战,如交通拥堵、环境污染、运输成本高昂等问题。因此,在物流运输中,优化问题的解决对于提高运输效率、降低成本、提高服务质量等方面具有重要意义。在物流运输中,优化问题通常可以用数学模型来表示,例如:运输成本最小化:找到一种运输方式,使得...

两个activity之间透明过渡效果和经验_activity切换渐透明动画-程序员宅基地

文章浏览阅读2.3k次。来看下效果图: 大致效果解释: 1. 当用户点击登录时logo下滑一定距离 2. 下滑后旋转90时 变化图标 3. 继续旋转90度 4. 然后移动到左上角 透明度渐变到上个activity 最后销毁当前activity术语登录界面我们 称为 A (本质是activity) 如下图 返回界面我们 称为 B (本质是activity) 如下图 大致思路让A界面的A_activity切换渐透明动画

双馈风力发电机DFIG滑模控制策略的MATLAB Simulink仿真研究 采用非线性控制滑模控制策略_双馈风机并网仿真matlab仿真-程序员宅基地

文章浏览阅读408次。未来,我们将进一步研究滑模控制策略在DFIG风力发电系统中的实际应用,为提高风力发电系统的效率和稳定性提供新的思路和方法。在仿真过程中,我们通过比较传统双PI环控制策略和本文提出的双环控制策略的仿真结果,发现本文提出的双环控制策略具有更好的功率跟随性能和稳定性。此外,通过调整仿真模型的参数,我们还对滑模控制器和PI调节器的性能进行了评估和优化。为了验证本文提出的双馈风力发电机滑模控制策略的正确性和优越性,我们搭建了MATLAB Simulink仿真模型进行仿真研究。三、PI调节器与滑模控制器。_双馈风机并网仿真matlab仿真

深入理解 Golang: 网络编程_golang 网络-程序员宅基地

文章浏览阅读885次。在 Go 中,内部采用结合阻塞模型和多路复用的方法。在这里就不再是线程操作 Socket,而是 Goroutine 协程。每个协程关心一个 Socket 连接:1.在底层使用操作系统的多路复用 IO;2.在协程层次使用阻塞模型。3.阻塞协程时,休眠协程。_golang 网络

linux中vim怎么分栏,Vim+Taglist+AutoComplPop之代码目录分栏信息和自动补全提示(Ubuntu环境)...-程序员宅基地

文章浏览阅读429次。一步:首先在Ubuntu环境中安装ctags: sudo apt-get install ctags第二部:解压:1.$unzip -d taglist taglist_xx.zip2.$cd taglist复制到指定路径下:1.$cp doc/taglist.txt /usr/share/vim/vim73/doc/2.$cp plugin/taglist.vim /usr/share/vim..._linux下autocomplpop

推荐文章

热门文章

相关标签