Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3使用scrapy生成csv文件代码示例
Dec 28 Python
Python异常处理操作实例详解
May 10 Python
python自动截取需要区域,进行图像识别的方法
May 17 Python
详解如何从TensorFlow的mnist数据集导出手写体数字图片
Aug 05 Python
Django关于admin的使用技巧和知识点
Feb 10 Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 Python
django教程如何自学
Jul 31 Python
python3:excel操作之读取数据并返回字典 + 写入的案例
Sep 01 Python
Python3+Appium安装及Appium模拟微信登录方法详解
Feb 16 Python
解决Pyinstaller打包软件失败的一个坑
Mar 04 Python
Python实现文本文件拆分写入到多个文本文件的方法
Apr 18 Python
Python OpenCV 图像平移的实现示例
Jun 04 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
77A一级收信机修理记
2021/03/02 无线电
坏狼php学习 计数器实例代码
2008/06/15 PHP
discuz论坛 用户登录 后台程序代码
2008/11/27 PHP
PHP中实现中文字符进制转换原理分析
2011/12/06 PHP
遍历指定目录,并存储目录内所有文件属性信息的php代码
2016/10/28 PHP
微信小程序 消息推送php服务器验证实例详解
2017/03/30 PHP
PHP实现批量清空删除指定文件夹所有内容的方法
2017/05/30 PHP
PHP实现的贪婪算法实例
2017/10/17 PHP
php简单检测404页面的方法示例
2019/08/23 PHP
JavaScript Event学习第九章 鼠标事件
2010/02/08 Javascript
Node.js:Windows7下搭建的Node.js服务(来玩玩服务器端的javascript吧,这可不是前端js插件)
2011/06/27 Javascript
JavaScript 函数参数是传值(byVal)还是传址(byRef) 分享
2013/07/02 Javascript
js写的方法实现上传图片之后查看大图
2014/03/05 Javascript
javascript变量声明实例分析
2015/04/25 Javascript
全面解析Bootstrap中form、navbar的使用方法
2016/05/30 Javascript
AngularJS入门教程之多视图切换用法示例
2016/11/02 Javascript
JS利用正则表达式实现简单的密码强弱判断实例
2017/06/16 Javascript
vue异步axios获取的数据渲染到页面的方法
2018/08/09 Javascript
Node.js操作系统OS模块用法分析
2019/01/04 Javascript
[04:55]完美世界副总裁蔡玮:DOTA2的自由、公平与信任
2013/12/18 DOTA
[01:45]亚洲邀请赛互动指南虚拟物品介绍
2015/01/30 DOTA
[42:11]TNC vs Pain 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/20 DOTA
[01:29:17]RNG vs Liquid 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.23
2019/09/05 DOTA
布同自制Python函数帮助查询小工具
2011/03/13 Python
django开发教程之利用缓存文件进行页面缓存的方法
2017/11/10 Python
使用python编写监听端
2018/04/12 Python
Windows系统下PhantomJS的安装和基本用法
2018/10/21 Python
python实现图片转字符小工具
2019/04/30 Python
Python smtp邮件发送模块用法教程
2020/06/15 Python
Html5实现如何在两个div元素之间拖放图像
2013/03/29 HTML / CSS
LG西班牙网上商店:Tienda LG Online Es
2019/07/30 全球购物
实习单位推荐信范文
2013/11/27 职场文书
幼儿园中秋节活动方案
2014/02/06 职场文书
学生社团文化节开幕式主持词
2014/03/28 职场文书
2014年幼儿园工作总结
2014/11/10 职场文书
承诺书范本
2015/01/21 职场文书