Tensorflow之MNIST CNN实现并保存、加载模型


Posted in Python onJune 17, 2020

本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下

废话不说,直接上代码

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
 
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
 
#download the data
mnist = keras.datasets.mnist
 
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
 
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 
train_images = train_images / 255.0
test_images = test_images / 255.0
 
def create_model():
 # It's necessary to give the input_shape,or it will fail when you load the model
 # The error will be like : You are trying to load the 4 layer models to the 0 layer 
 model = keras.Sequential([
   keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
   keras.layers.MaxPool2D(),
   keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
   keras.layers.MaxPool2D(),
   keras.layers.Flatten(),
   keras.layers.Dense(576, activation=tf.nn.relu),
   keras.layers.Dense(10, activation=tf.nn.softmax)
 ])
 
 model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
 
 return model
 
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
 
 
#train
model = create_model()                         
model.fit(train_images, train_labels, epochs=4)
 
#save the model
model.save('my_model.h5')
 
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)

模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试

#Load the model
 
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(), 
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
new_model.summary()
 
#Evaluate
 
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
 
#Predicte
 
mypath = 'C:\\pythonp\\testdir2'
 
def getimg(mypath):
  listdir = os.listdir(mypath)
  imgs = []
  for p in listdir:
    img = plt.imread(mypath+'\\'+p)
    # I save the picture that I draw myself under Windows, but the saved picture's
    # encode style is just opposite with the experiment data, so I transfer it with
    # this line. 
    img = np.abs(img/255-1)
    imgs.append(img[:,:,0])
  return np.array(imgs),len(imgs)
 
imgs = getimg(mypath)
 
test_images = np.reshape(imgs[0],[-1,28,28,1])
 
predictions = new_model.predict(test_images)
 
plt.figure()
 
for i in range(imgs[1]):
 c = np.argmax(predictions[i])
 plt.subplot(3,3,i+1)
 plt.xticks([])
 plt.yticks([])
 plt.imshow(test_images[i,:,:,0])
 plt.title(class_names[c])
plt.show()

测试结果

Tensorflow之MNIST CNN实现并保存、加载模型

自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 正则表达式 概述及常用字符
May 04 Python
Fiddler如何抓取手机APP数据包
Jan 22 Python
django1.8使用表单上传文件的实现方法
Nov 04 Python
基于python requests库中的代理实例讲解
May 07 Python
Django学习教程之静态文件的调用详解
May 08 Python
python批量修改图片大小的方法
Jul 24 Python
python3对拉勾数据进行可视化分析的方法详解
Apr 03 Python
关于Python字符串显示u...的解决方式
Mar 06 Python
BeautifulSoup获取指定class样式的div的实现
Dec 07 Python
虚拟环境及venv和virtualenv的区别说明
Feb 05 Python
python 中 .py文件 转 .pyd文件的操作
Mar 04 Python
Python使用openpyxl模块处理Excel文件
Jun 05 Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
You might like
PHP简单系统数据添加以及数据删除模块源文件下载
2008/06/07 PHP
php 时间计算问题小结
2009/01/04 PHP
Php图像处理类代码分享
2012/01/19 PHP
如何在Web页面上直接打开、编辑、创建Office文档
2007/03/12 Javascript
JS获取月的最后一天与JS得到一个月份最大天数的实例代码
2013/12/16 Javascript
PHP实现的各种中文编码转换类分享
2015/01/23 Javascript
在JavaScript中用getMinutes()方法返回指定的分时刻
2015/06/10 Javascript
AngularJS教程 ng-style 指令简单示例
2016/08/03 Javascript
bootstrap table操作技巧分享
2017/02/15 Javascript
Vue2学习笔记之请求数据交互vue-resource
2017/02/23 Javascript
详解webpack 多页面/入口支持&公共组件单独打包
2017/06/29 Javascript
详解Vue.js中引入图片路径的几种方式
2019/06/17 Javascript
vue+element-ui JYAdmin后台管理系统模板解析
2020/07/28 Javascript
Vue 组件的挂载与父子组件的传值实例
2020/09/02 Javascript
[02:27]刀塔重生降临
2015/10/14 DOTA
[00:17]天涯墨客一技能展示
2018/08/25 DOTA
TensorFlow 实战之实现卷积神经网络的实例讲解
2018/02/26 Python
python定向爬取淘宝商品价格
2018/02/27 Python
深入理解Python爬虫代理池服务
2018/02/28 Python
Python实现的建造者模式示例
2018/08/06 Python
使用python的pexpect模块,实现远程免密登录的示例
2019/02/14 Python
初次部署django+gunicorn+nginx的方法步骤
2019/09/11 Python
python/Matplotlib绘制复变函数图像教程
2019/11/21 Python
Django中使用MySQL5.5的教程
2019/12/18 Python
python Matplotlib基础--如何添加文本和标注
2021/01/26 Python
APM Monaco中国官网:来自摩纳哥珠宝品牌
2017/12/27 全球购物
美国购买新书和二手书网站:Better World Books
2018/10/31 全球购物
护士实习自我鉴定
2013/10/22 职场文书
《口技》教学反思
2014/02/21 职场文书
公司投资建议书
2014/05/16 职场文书
教师批评与自我批评发言稿
2014/10/15 职场文书
2014年维修工作总结
2014/11/22 职场文书
单位实习鉴定评语
2015/01/04 职场文书
2015年教研组工作总结
2015/05/04 职场文书
信仰纪录片观后感
2015/06/08 职场文书
go语言使用Casbin实现角色的权限控制
2021/06/26 Golang