Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
PHP魔术方法__ISSET、__UNSET使用实例
Nov 25 Python
解决python2.7用pip安装包时出现错误的问题
Jan 23 Python
Python基于正则表达式实现文件内容替换的方法
Aug 30 Python
Python内建模块struct实例详解
Feb 02 Python
Python切片操作实例分析
Mar 16 Python
Django压缩静态文件的实现方法详析
Aug 26 Python
padas 生成excel 增加sheet表的实例
Dec 11 Python
Python 迭代,for...in遍历,迭代原理与应用示例
Oct 12 Python
pytorch实现seq2seq时对loss进行mask的方式
Feb 18 Python
如何导出python安装的所有模块名称和版本号到文件中
Jun 05 Python
解决Keras的自定义lambda层去reshape张量时model保存出错问题
Jul 01 Python
python爬虫使用正则爬取网站的实现
Aug 03 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
PHP mb_convert_encoding文字编码的转换函数介绍
2011/11/10 PHP
php微信开发之自定义菜单实现
2016/11/18 PHP
PHP实现基本留言板功能原理与步骤详解
2020/03/26 PHP
js 屏蔽鼠标右键脚本附破解方法
2009/12/03 Javascript
屏蔽F1~F12的快捷键的js函数
2010/05/06 Javascript
浏览器打开层自动缓慢展开收缩实例代码
2013/07/04 Javascript
Js日期选择自动填充到输入框(界面漂亮兼容火狐)
2013/08/02 Javascript
巧用局部变量提升javascript性能
2014/02/24 Javascript
Jquery实现点击按钮,连续地向textarea中添加值的实例代码
2014/03/08 Javascript
javascript检测flash插件是否被禁用的方法
2016/01/14 Javascript
原生javascript实现的一个简单动画效果
2016/03/30 Javascript
js添加事件的通用方法推荐
2016/05/15 Javascript
JavaScript实现页面跳转的方式汇总
2016/05/16 Javascript
Angularjs 制作购物车功能实例代码
2016/09/14 Javascript
javascript汉字拼音互转的简单实例
2016/10/09 Javascript
Angularjs2不同组件间的通信实例代码
2017/05/06 Javascript
js根据需要计算数组中重复出现某个元素的个数
2019/01/18 Javascript
layer iframe 设置关闭按钮的方法
2019/09/12 Javascript
Vue解析剪切板图片并实现发送功能
2020/02/04 Javascript
微信小程序实现手指拖动选项排序
2020/04/22 Javascript
[01:22:42]2014 DOTA2华西杯精英邀请赛 5 24 DK VS LGD
2014/05/26 DOTA
Python内置函数 next的具体使用方法
2017/11/24 Python
Python+matplotlib+numpy实现在不同平面的二维条形图
2018/01/02 Python
使用 Django Highcharts 实现数据可视化过程解析
2019/07/31 Python
python config文件的读写操作示例
2019/09/27 Python
python读取word 中指定位置的表格及表格数据
2019/10/23 Python
python关闭占用端口方式
2019/12/17 Python
Python 实现简单的客户端认证
2020/07/29 Python
python+django+selenium搭建简易自动化测试
2020/08/19 Python
python中如何使用虚拟环境
2020/10/14 Python
物业管理员岗位职责范文
2013/11/25 职场文书
幼儿园教育教学反思
2014/01/31 职场文书
军人违纪检讨书
2014/02/04 职场文书
安全承诺书格式
2014/05/21 职场文书
2014年科研工作总结
2014/12/03 职场文书
初中生思想道德自我评价
2015/03/09 职场文书