Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pycharm 使用心得(四)显示行号
Jun 05 Python
Python字符串详细介绍
May 09 Python
Python 3中的yield from语法详解
Jan 18 Python
Python中实现变量赋值传递时的引用和拷贝方法
Apr 29 Python
Python(Django)项目与Apache的管理交互的方法
May 16 Python
Python实例方法、类方法、静态方法的区别与作用详解
Mar 25 Python
Django 1.10以上版本 url 配置注意事项详解
Aug 05 Python
python 使用while写猜年龄小游戏过程解析
Oct 07 Python
Python实现检测文件的MD5值来查找重复文件案例
Mar 12 Python
Python识别处理照片中的条形码
Nov 16 Python
解决pip安装tensorflow中出现的no module named tensorflow.python 问题方法
Feb 20 Python
python单向链表实例详解
May 25 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
用Apache反向代理设置对外的WWW和文件服务器
2006/10/09 PHP
坏狼php学习 计数器实例代码
2008/06/15 PHP
php分页函数示例代码分享
2014/02/24 PHP
PHP图像裁剪缩略裁切类源码及使用方法
2016/01/07 PHP
PHP检测数据类型的几种方法(总结)
2017/03/04 PHP
PHP基于自定义类随机生成姓名的方法示例
2017/08/05 PHP
PHP7下协程的实现方法详解
2017/12/17 PHP
PHP生成随机数的方法总结
2018/03/01 PHP
JavaScript 捕获窗口关闭事件
2009/07/26 Javascript
JavaScript 数组循环引起的思考
2010/01/01 Javascript
基于jquery的大众点评,分类导航实现代码
2011/08/23 Javascript
jquery实现邮箱自动补全功能示例分享
2014/02/17 Javascript
node.js中watch机制详解
2014/11/17 Javascript
jQuery四种选择器使用及示例
2016/06/05 Javascript
用jQuery.ajaxSetup实现对请求和响应数据的过滤
2016/12/20 Javascript
使用jsonp实现跨域获取数据实例讲解
2016/12/25 Javascript
vux uploader 图片上传组件的安装使用方法
2018/05/15 Javascript
解决vue单页路由跳转后scrollTop的问题
2018/09/03 Javascript
详解easyui 切换主题皮肤
2019/04/04 Javascript
vue中格式化时间过滤器代码实例
2019/04/17 Javascript
js 判断当前时间是否处于某个一个时间段内
2019/09/19 Javascript
vue使用exif获取图片经纬度的示例代码
2020/12/11 Vue.js
[02:40]DOTA2超级联赛专访430 从小就爱玩对抗性游戏
2013/06/18 DOTA
[00:30]塑造者的传承礼包-戴泽“暗影之焰”套装展示视频
2014/04/04 DOTA
[01:01:31]2018DOTA2亚洲邀请赛3月29日小组赛B组 Mineski VS paiN
2018/03/30 DOTA
Python操作列表的常用方法分享
2014/02/13 Python
Python中MYSQLdb出现乱码的解决方法
2014/10/11 Python
python腾讯语音合成实现过程解析
2019/08/01 Python
pytorch多GPU并行运算的实现
2019/09/27 Python
如何在Python对Excel进行读取
2020/06/04 Python
用Python实现童年贪吃蛇小游戏功能的实例代码
2020/12/07 Python
美国伴娘礼服商店:Evening Collective
2019/10/07 全球购物
双创工作实施方案
2014/03/26 职场文书
促销活动总结报告
2014/04/26 职场文书
浅谈Python从全局与局部变量到装饰器的相关知识
2021/06/21 Python
使用springMVC所需要的pom配置
2021/09/15 Java/Android