Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3 入门教程 简单但比较不错
Nov 29 Python
Python自动化构建工具scons使用入门笔记
Mar 10 Python
Python对象转JSON字符串的方法
Apr 27 Python
python 根据正则表达式提取指定的内容实例详解
Dec 04 Python
python 上下文管理器使用方法小结
Oct 10 Python
Python爬虫工程师面试问题总结
Mar 22 Python
Python 数据处理库 pandas进阶教程
Apr 21 Python
Django中使用Celery的教程详解
Aug 24 Python
聊聊python里如何用Borg pattern实现的单例模式
Jun 06 Python
python语言线程标准库threading.local解读总结
Nov 10 Python
pycharm通过anaconda安装pyqt5的教程
Mar 24 Python
Python3爬虫中关于Ajax分析方法的总结
Jul 10 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
转生史莱姆:萌王第一次撸串开心到飞起,哥布塔撸串却神似界王神
2018/11/30 日漫
关于PHPDocument 代码注释规范的总结
2013/06/25 PHP
谈谈PHP连接Access数据库的注意事项
2016/08/12 PHP
yii2.0整合阿里云oss删除单个文件的方法
2017/09/19 PHP
Yii框架学习笔记之session与cookie简单操作示例
2019/04/30 PHP
PHP实现二维数组(或多维数组)转换成一维数组的常见方法总结
2019/12/04 PHP
Prototype使用指南之base.js
2007/01/10 Javascript
JavaScript DOM 学习第七章 表单的扩展
2010/02/19 Javascript
如何制作浮动广告 JavaScript制作浮动广告代码
2012/12/30 Javascript
JavaScript 处理Iframe自适应高度(同或不同域名下)
2013/03/29 Javascript
什么是Node.js?Node.js详细介绍
2014/06/01 Javascript
Jquery实现$.fn.extend和$.extend函数
2016/04/14 Javascript
jquery实现全选、全不选以及单选功能
2017/03/23 jQuery
JavaScript瀑布流布局实现代码
2017/05/06 Javascript
vue项目持久化存储数据的实现代码
2018/10/01 Javascript
Vue看了就会的8个小技巧
2021/01/21 Vue.js
[54:53]2014 DOTA2国际邀请赛中国区预选赛 LGD-GAMING VS CIS 第二场
2014/05/23 DOTA
[10:49]2014国际邀请赛 叨叨刀塔第二期为真正的电竞喝彩
2014/07/21 DOTA
实现python版本的按任意键继续/退出
2016/09/26 Python
Django自定义分页与bootstrap分页结合
2021/02/22 Python
Python3.4 tkinter,PIL图片转换
2018/06/21 Python
Django urls.py重构及参数传递详解
2019/07/23 Python
python创建学生管理系统
2019/11/22 Python
python脚本实现mp4中的音频提取并保存在原目录
2020/02/27 Python
Python object类中的特殊方法代码讲解
2020/03/06 Python
详解python tkinter 图片插入问题
2020/09/03 Python
使用HTML5做的导航条详细步骤
2020/10/19 HTML / CSS
Nice Kicks网上商店:ShopNiceKicks.com
2018/12/25 全球购物
波兰家居和花园家具专家:4Home
2019/05/26 全球购物
英国Lookfantastic中文网站:护肤品美妆美发购物(英国直邮)
2020/04/27 全球购物
初中校园之声广播稿
2014/01/15 职场文书
法院四风对照检查材料思想汇报
2014/10/06 职场文书
公司岗位说明书
2015/10/08 职场文书
2016入党积极分子党课培训心得体会
2016/01/06 职场文书
Win11 引入 Windows 365 云操作系统,适应疫情期间混合办公模式:启动时直接登录、模
2022/04/06 数码科技
Spring Boot 的创建和运行示例代码详解
2022/07/23 Java/Android