Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的__slots__使用示例
Feb 26 Python
Django中的“惰性翻译”方法的相关使用
Jul 27 Python
Python基于FTP模块实现ftp文件上传操作示例
Apr 23 Python
pip命令无法使用的解决方法
Jun 12 Python
浅谈pandas用groupby后对层级索引levels的处理方法
Nov 06 Python
Python反爬虫技术之防止IP地址被封杀的讲解
Jan 09 Python
Python基础之文件读取的讲解
Feb 16 Python
Python实现根据日期获取当天凌晨时间戳的方法示例
Apr 09 Python
python的pygal模块绘制反正切函数图像方法
Jul 16 Python
python面向对象 反射原理解析
Aug 12 Python
python3 selenium自动化测试 强大的CSS定位方法
Aug 23 Python
python同步两个文件夹下的内容
Aug 29 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
php把session写入数据库示例
2014/02/26 PHP
ThinkPHP的I方法使用详解
2014/06/18 PHP
php使用百度ping服务代码实例
2014/06/19 PHP
js+FSO遍历文件夹下文件并显示
2007/03/07 Javascript
javascript alert乱码的解决方法
2013/11/05 Javascript
Js+Jq获取URL参数的集中方法示例代码
2014/05/20 Javascript
基于jquery实现的文字向上跑动类似跑马灯的效果
2014/06/17 Javascript
javascript使用for循环批量注册的事件不能正确获取索引值的解决方法
2014/12/20 Javascript
基于zepto.js实现仿手机QQ空间的大图查看组件ImageView.js详解
2015/03/05 Javascript
JavaScript 实现完美兼容多浏览器的复制功能代码
2015/04/28 Javascript
jQuery原型属性和原型方法详解
2015/07/07 Javascript
浅谈$(document)和$(window)的区别
2015/07/15 Javascript
微信小程序 刷新上拉下拉不会断详细介绍
2017/05/11 Javascript
ExtJs的Ext.Ajax.request实现waitMsg等待提示效果
2017/06/14 Javascript
el-select 下拉框多选实现全选的实现
2019/08/02 Javascript
jQuery实现评论模块
2020/08/19 jQuery
[03:06]V社市场总监Dota2项目负责人Erik专访:希望更多中国玩家加入DOTA2
2014/07/11 DOTA
[01:58]DOTA2上海特级锦标赛现场采访:RTZ这个ID到底好不好
2016/03/25 DOTA
[02:07]2018DOTA2亚洲邀请赛主赛事第三日五佳镜头 fy极限反杀
2018/04/06 DOTA
[49:56]VG vs Optic 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
Python实现对象转换为xml的方法示例
2017/06/08 Python
python 3调用百度OCR API实现剪贴板文字识别
2018/09/04 Python
Python绘制并保存指定大小图像的方法
2019/01/10 Python
Python生成一个迭代器的实操方法
2019/06/18 Python
python中的subprocess.Popen()使用详解
2019/12/25 Python
使用Python操作ArangoDB的方法步骤
2020/02/02 Python
python实现拼图小游戏
2020/02/22 Python
jupyter notebook的安装与使用详解
2020/05/18 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
2020/05/25 Python
python 基于UDP协议套接字通信的实现
2021/01/22 Python
zooplus波兰:在线宠物店
2019/07/21 全球购物
股东协议书范本
2014/04/14 职场文书
代理词怎么写
2015/05/25 职场文书
2019年大学生职业生涯规划书最新范文
2019/03/25 职场文书
Python 中的 copy()和deepcopy()
2021/11/07 Python
WINDOWS下安装mysql 8.x 的方法图文教程
2022/04/19 MySQL