Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Mysql自动备份脚本
Jul 14 Python
Python备份Mysql脚本
Aug 11 Python
python实现apahce网站日志分析示例
Apr 02 Python
python中lambda与def用法对比实例分析
Apr 30 Python
Python性能提升之延迟初始化
Dec 04 Python
Django基础之Model操作步骤(介绍)
May 27 Python
python opencv 图像尺寸变换方法
Apr 02 Python
python 利用pandas将arff文件转csv文件的方法
Feb 12 Python
python字符串切割:str.split()与re.split()的对比分析
Jul 16 Python
linux环境下Django的安装配置详解
Jul 22 Python
通过实例学习Python Excel操作
Jan 06 Python
python Gabor滤波器讲解
Oct 26 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
一个改进的UBB类
2006/10/09 PHP
PHP新手上路(十)
2006/10/09 PHP
IIS下配置Php+Mysql+zend的图文教程
2006/12/08 PHP
PHP中使用SimpleXML检查XML文件结构实例
2015/01/07 PHP
PHP Laravel中的Trait使用方法
2019/01/20 PHP
常见的5个PHP编码小陋习以及优化实例讲解
2021/02/27 PHP
判断浏览器的javascript版本的代码
2010/09/03 Javascript
浅谈js中的闭包
2015/03/16 Javascript
jQuery实现限制textarea文本框输入字符数量的方法
2015/05/28 Javascript
jQuery实现的简单折叠菜单(折叠面板)效果代码
2015/09/16 Javascript
关于Vue单页面骨架屏实践记录
2017/12/13 Javascript
jQuery实现动态加载select下拉列表项功能示例
2018/05/31 jQuery
深入浅析js原型链和vue构造函数
2018/10/25 Javascript
JavaScript前端开发时数值运算的小技巧
2020/07/28 Javascript
全局安装 Vue cli3 和 继续使用 Vue-cli2.x操作
2020/09/08 Javascript
Python中实现两个字典(dict)合并的方法
2014/09/23 Python
全面了解Python的getattr(),setattr(),delattr(),hasattr()
2016/06/14 Python
Python 两个列表的差集、并集和交集实现代码
2016/09/21 Python
Python结巴中文分词工具使用过程中遇到的问题及解决方法
2017/04/15 Python
python发送邮件实例分享
2017/07/28 Python
windows下python之mysqldb模块安装方法
2017/09/07 Python
取numpy数组的某几行某几列方法
2018/04/03 Python
python 解决动态的定义变量名,并给其赋值的方法(大数据处理)
2018/11/10 Python
python3利用Axes3D库画3D模型图
2020/03/25 Python
全球销量第一生发产品:Viviscal
2017/12/21 全球购物
Feelunique中文官网:欧洲最大化妆品零售电商
2020/07/10 全球购物
教师节促销方案
2014/03/22 职场文书
社区义诊活动总结
2014/04/30 职场文书
学校2014重阳节活动策划方案
2014/09/16 职场文书
党员民主生活会个人整改措施材料
2014/09/16 职场文书
国庆节促销广告语2014
2014/09/19 职场文书
八项规定个人对照检查材料思想汇报
2014/09/25 职场文书
女性健康知识讲座主持词
2015/07/04 职场文书
python自动计算图像数据集的RGB均值
2021/06/18 Python
Redis高可用集群redis-cluster详解
2022/03/20 Redis
Python使用PyYAML库读写yaml文件的方法
2022/04/06 Python