Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置函数的用法实例教程
Sep 08 Python
Python使用gensim计算文档相似性
Apr 10 Python
python使用正则表达式匹配字符串开头并打印示例
Jan 11 Python
Python模块结构与布局操作方法实例分析
Jul 24 Python
Python实现拷贝/删除文件夹的方法详解
Aug 29 Python
Python使用os.listdir()和os.walk()获取文件路径与文件下所有目录的方法
Apr 01 Python
解决Mac下使用python的坑
Aug 13 Python
树莓派3 搭建 django 服务器的实例
Aug 29 Python
python实现网站微信登录的示例代码
Sep 18 Python
wxPython实现分隔窗口
Nov 19 Python
python-numpy-指数分布实例详解
Dec 07 Python
Elasticsearch 批量操作
Apr 19 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
PHP.MVC的模板标签系统(三)
2006/09/05 PHP
基于PHP字符串的比较函数strcmp()与strcasecmp()的使用详解
2013/05/15 PHP
解析ajax事件的调用顺序
2013/06/17 PHP
php+js iframe实现上传头像界面无跳转
2014/04/29 PHP
laravel框架模型中非静态方法也能静态调用的原理分析
2019/11/23 PHP
PHP设计模式之建造者模式(Builder)原理与用法案例详解
2019/12/12 PHP
客户端静态页面玩分页
2006/06/26 Javascript
jQuery总体架构的理解分析
2011/03/07 Javascript
DIV外区域Click后关闭DIV的实现代码
2011/12/21 Javascript
jquery 中多条件选择器,相对选择器,层次选择器的区别
2012/07/03 Javascript
jQuery简单实现仿京东分类导航层效果
2016/06/07 Javascript
浅析Bootstrap验证控件的使用
2016/06/23 Javascript
JS DOMReady事件的六种实现方法总结
2016/11/23 Javascript
js学习总结之DOM2兼容处理顺序问题的解决方法
2017/07/27 Javascript
layui的table单击行勾选checkbox功能方法
2018/08/14 Javascript
使用微信SDK自定义分享的方法
2019/07/03 Javascript
keep-Alive搭配vue-router实现缓存页面效果的示例代码
2020/06/24 Javascript
详解JavaScript中的数据类型,以及检测数据类型的方法
2020/09/17 Javascript
举例详解Python中threading模块的几个常用方法
2015/06/18 Python
python中os和sys模块的区别与常用方法总结
2017/11/14 Python
基于python 爬虫爬到含空格的url的处理方法
2018/05/11 Python
基于wxPython的GUI实现输入对话框(1)
2019/02/27 Python
Python数学形态学实例分析
2019/09/06 Python
Pycharm IDE的安装和使用教程详解
2020/04/30 Python
15款Python编辑器的优缺点,别再问我“选什么编辑器”啦
2020/10/19 Python
Craghoppers德国官网:户外和旅行服装
2020/02/14 全球购物
金鑫耀Java笔试题
2014/09/06 面试题
服装公司总经理岗位职责
2013/11/30 职场文书
手机被没收检讨书
2014/02/22 职场文书
小学美术兴趣小组活动总结
2014/07/07 职场文书
学校安全工作汇报材料
2014/08/16 职场文书
领导批评与自我批评范文
2014/10/16 职场文书
2014年职称评定工作总结
2014/11/26 职场文书
回复函范文
2015/07/14 职场文书
2019年员工旷工保证书!
2019/06/28 职场文书
奖学金申请书(范文)
2019/08/14 职场文书