Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的Flask框架中实现分页功能的教程
Apr 20 Python
使用Python内置的模块与函数进行不同进制的数的转换
Mar 12 Python
python实现TF-IDF算法解析
Jan 02 Python
使用python根据端口号关闭进程的方法
Nov 06 Python
python re库的正则表达式入门学习教程
Mar 08 Python
python hough变换检测直线的实现方法
Jul 12 Python
解决python replace函数替换无效问题
Jan 18 Python
Tensorflow 模型转换 .pb convert to .lite实例
Feb 12 Python
Window系统下Python如何安装OpenCV库
Mar 05 Python
django中嵌套的try-except实例
May 21 Python
keras使用Sequence类调用大规模数据集进行训练的实现
Jun 22 Python
浅谈Python响应式类库RxPy
Jun 14 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
ADODB结合SMARTY使用~超级强
2006/11/25 PHP
PHP date函数参数详解
2006/11/27 PHP
PHP 源代码压缩小工具
2009/12/22 PHP
redis 队列操作的例子(php)
2012/04/12 PHP
使用PHPStorm+XDebug搭建单步调试环境
2017/11/19 PHP
yii2中关于加密解密的那些事儿
2018/06/12 PHP
PHP获取ttf格式文件字体名的方法示例
2019/03/06 PHP
jquery中eq和get的区别与使用方法
2011/04/14 Javascript
JavaScript实现的石头剪刀布游戏源码分享
2014/08/22 Javascript
js读取csv文件并使用json显示出来
2015/01/09 Javascript
jquery实现全屏滚动
2015/12/28 Javascript
jQuery unbind()方法实例详解
2016/01/19 Javascript
Javascript的比较汇总
2016/07/25 Javascript
基于js实现checkbox批量选中操作
2016/11/22 Javascript
vue vuex vue-rouert后台项目——权限路由(适合初学)
2017/12/29 Javascript
vscode中vue-cli项目es-lint的配置方法
2018/07/30 Javascript
js使用ajax传值给后台,后台返回字符串处理方法
2018/08/08 Javascript
微信小程序自定义toast弹窗效果的实现代码
2018/11/15 Javascript
JS监听滚动和id自动定位滚动
2018/12/18 Javascript
JS数组Object.keys()方法的使用示例
2019/06/05 Javascript
jquery html添加元素/删除元素操作实例详解
2020/05/20 jQuery
javascript实现移动端轮播图
2020/12/09 Javascript
python网络编程实例简析
2014/09/26 Python
Python实现KNN邻近算法
2021/01/28 Python
解决TensorFlow模型恢复报错的问题
2020/02/06 Python
python+selenium爬取微博热搜存入Mysql的实现方法
2021/01/27 Python
Python爬虫+tkinter界面实现历史天气查询的思路详解
2021/02/22 Python
数据库测试通常都包括哪些方面
2015/11/30 面试题
求职简历自荐信范文
2013/10/21 职场文书
父母对孩子的寄语
2014/04/09 职场文书
《最大的麦穗》教学反思
2014/04/17 职场文书
2015年母亲节活动策划方案
2015/05/04 职场文书
2016年庆祝六一儿童节活动总结
2016/04/06 职场文书
python基础入门之普通操作与函数(三)
2021/06/13 Python
CSS 实现角标效果的完整代码
2022/06/28 HTML / CSS
Go中使用gjson来操作JSON数据的实现
2022/08/14 Golang