Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 字符串格式化代码
Mar 17 Python
python 布尔操作实现代码
Mar 23 Python
Python字符串切片操作知识详解
Mar 28 Python
Python实现将罗马数字转换成普通阿拉伯数字的方法
Apr 19 Python
Python爬虫之正则表达式基本用法实例分析
Aug 08 Python
使用 Python 实现微信群友统计器的思路详解
Sep 26 Python
Django1.11自带分页器paginator的使用方法
Oct 31 Python
Django import export实现数据库导入导出方式
Apr 03 Python
tensorflow常用函数API介绍
Apr 19 Python
浅谈python量化 双均线策略(金叉死叉)
Jun 03 Python
通过Python实现Payload分离免杀过程详解
Jul 13 Python
Python基于callable函数检测对象是否可被调用
Oct 16 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
用PHP查询搜索引擎排名位置的代码
2010/01/05 PHP
php继承的一个应用
2011/09/06 PHP
基于JQuery+PHP编写砸金蛋中奖程序
2015/09/08 PHP
鼠标拖动实现DIV排序示例代码
2013/10/14 Javascript
解决Jquery鼠标经过不停滑动的问题
2014/03/03 Javascript
javascript中声明函数的方法及调用函数的返回值
2014/07/22 Javascript
jQuery实现倒计时按钮功能代码分享
2014/09/03 Javascript
Javascript数组Array方法解读
2016/03/13 Javascript
EXT中单击button按钮grid添加一行(光标位置可设置)的实例代码
2016/06/02 Javascript
JS实现显示带倒影的图片横排居中放大展示特效实例【测试可用】
2016/08/23 Javascript
JavaScript中return用法示例
2016/11/29 Javascript
angular中实现控制器之间传递参数的方式
2017/04/24 Javascript
Js实现京东无延迟菜单效果实例(demo)
2017/06/02 Javascript
Javascript中JSON数据分组优化实践及JS操作JSON总结
2017/12/22 Javascript
vue+axios 前端实现的常用拦截的代码示例
2018/08/23 Javascript
微信小程序swiper左右扩展各显示一半代码实例
2019/12/05 Javascript
Vue设置长时间未操作登录自动到期返回登录页
2020/01/22 Javascript
python实现在无须过多援引的情况下创建字典的方法
2014/09/25 Python
Python装饰器原理与用法分析
2018/04/30 Python
pandas 把数据写入txt文件每行固定写入一定数量的值方法
2018/12/28 Python
梅尔频率倒谱系数(mfcc)及Python实现
2019/06/18 Python
pytorch 修改预训练model实例
2020/01/18 Python
TensorFlow获取加载模型中的全部张量名称代码
2020/02/11 Python
联想墨西哥官方网站:Lenovo墨西哥
2016/08/17 全球购物
Rentalcars.com中国:世界上最大的在线汽车租赁服务
2019/08/22 全球购物
4s店总经理岗位职责
2013/12/31 职场文书
二年级班级文化建设方案
2014/05/10 职场文书
员工教育培训协议书
2014/09/27 职场文书
群众路线教育实践活动学习笔记
2014/11/05 职场文书
企业2014年度工作总结
2014/12/10 职场文书
班主任经验交流材料
2014/12/16 职场文书
工程资料员岗位职责
2015/04/13 职场文书
红色电影观后感
2015/06/18 职场文书
教师远程培训心得体会
2016/01/09 职场文书
入门学习Go的基本语法
2021/07/07 Golang
MySQL创建管理子分区
2022/04/13 MySQL