Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python list 与 NumPy.ndarry 切片之间的对比
Jul 24 Python
Python实现的生成格雷码功能示例
Jan 24 Python
pandas将DataFrame的列变成行索引的方法
Apr 10 Python
Python设计模式之原型模式实例详解
Jan 18 Python
python 动态迁移solr数据过程解析
Sep 04 Python
sklearn+python:线性回归案例
Feb 24 Python
jupyter notebook 多环境conda kernel配置方式
Apr 10 Python
python 操作mysql数据中fetchone()和fetchall()方式
May 15 Python
python 实时调取摄像头的示例代码
Nov 25 Python
python 发送邮件的示例代码(Python2/3都可以直接使用)
Dec 03 Python
python statsmodel的使用
Dec 21 Python
Python 中的单分派泛函数你真的了解吗
Jun 22 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
一个php作的文本留言本的例子(三)
2006/10/09 PHP
不用数据库的多用户文件自由上传投票系统(2)
2006/10/09 PHP
解析PHP中empty is_null和isset的测试
2013/06/29 PHP
PHP实现微信图片上传到服务器的方法示例
2017/06/29 PHP
基于 Swoole 的微信扫码登录功能实现代码
2018/01/15 PHP
PHP内置函数生成随机数实例
2019/01/18 PHP
PHP实现一个按钮点击上传多个图片操作示例
2020/01/23 PHP
js函数使用技巧之 setTimeout(function(){},0)
2009/02/09 Javascript
javascript作用域容易记错的两个地方分析
2012/06/22 Javascript
js实现点击图片改变页面背景图的方法
2015/02/28 Javascript
javascript针对不确定函数的执行方法
2015/12/16 Javascript
javascript实现倒计时跳转页面
2016/01/17 Javascript
Node.js websocket使用socket.io库实现实时聊天室
2017/02/20 Javascript
详解vue-cli项目中用json-sever搭建mock服务器
2017/11/02 Javascript
发布一款npm包帮助理解npm的使用
2019/01/03 Javascript
mpvue微信小程序多列选择器用法之省份城市选择的实现
2019/03/07 Javascript
Layui选项卡制作历史浏览记录的方法
2019/09/28 Javascript
JS数组的高级使用方法示例小结
2020/03/14 Javascript
利用js canvas实现五子棋游戏
2020/10/11 Javascript
在antd Table中插入可编辑的单元格实例
2020/10/28 Javascript
一文秒懂nodejs中的异步编程
2021/01/28 NodeJs
精确查找PHP WEBSHELL木马的方法(1)
2011/04/12 Python
Python实现微信公众平台自定义菜单实例
2015/03/20 Python
Python处理Excel文件实例代码
2017/06/20 Python
深入理解Python3 内置函数大全
2017/11/23 Python
Python机器学习库scikit-learn安装与基本使用教程
2018/06/25 Python
使用EduBlock轻松学习Python编程
2018/10/08 Python
如何用Python徒手写线性回归
2021/01/25 Python
护理学毕业生自荐信
2013/10/02 职场文书
求职信模板怎么做
2014/01/26 职场文书
工艺员岗位职责
2014/02/11 职场文书
房产代理公证处委托书
2014/04/04 职场文书
践行三严三实心得体会
2014/10/13 职场文书
《家世》读后感:看家训的力量
2019/12/30 职场文书
Python基于百度AI实现抓取表情包
2021/06/27 Python
win10电脑老是死机怎么办?win10系统老是死机的解决方法
2022/08/05 数码科技