Tensorflow实现在训练好的模型上进行测试


Posted in Python onJanuary 20, 2020

Tensorflow可以使用训练好的模型对新的数据进行测试,有两种方法:第一种方法是调用模型和训练在同一个py文件中,中情况比较简单;第二种是训练过程和调用模型过程分别在两个py文件中。本文将讲解第二种方法。

模型的保存

tensorflow提供可保存训练模型的接口,使用起来也不是很难,直接上代码讲解:

#网络结构
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#损失函数与优化函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess: 
    sess.run(init) 
    saver.save(sess,"save/model.ckpt") 
    train_step.run({x: train_x, y_: train_y})

以上代码就完成了模型的保存,值得注意的是下面这行代码

tf.add_to_collection('network-output', y)

这行代码保存了神经网络的输出,这个在后面使用导入模型过程中起到关键作用。

模型的导入

模型训练并保存后就可以导入来评估模型在测试集上的表现,网上很多文章只用简单的四则运算来做例子,让人看的头大。还是先上代码:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('./model.ckpt.meta')
  saver.restore(sess, './model.ckpt')# .data文件
  pred = tf.get_collection('network-output')[0]

  graph = tf.get_default_graph()
  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

  y = sess.run(pred, feed_dict={x: test_x, y_: test_y})

讲解一下关键的代码,首先是pred = tf.get_collection('pred_network')[0],这行代码获得训练过程中网络输出的“接口”,简单理解就是,通过tf.get_collection() 这个方法获取了整个网络结构。获得网络结构后我们就需要喂它对应的数据y = sess.run(pred, feed_dict={x: test_x, y_: test_y}) 在训练过程中我们的输入是

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

因此导入模型后所需的输入也要与之对应可使用以下代码获得:

x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

使用模型的最后一步就是输入测试集,然后按照训练好的网络进行评估

sess.run(pred, feed_dict={x: test_x, y_: test_y})

理解下这行代码,sess.run() 的函数原型为

run(fetches, feed_dict=None, options=None, run_metadata=None)

Tensorflow对 feed_dict 执行fetches操作,因此在导入模型后的运算就是,按照训练的网络计算测试输入的数据。

以上这篇Tensorflow实现在训练好的模型上进行测试就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python with的用法
Aug 22 Python
python变量不能以数字打头详解
Jul 06 Python
Python学习笔记之if语句的使用示例
Oct 23 Python
python中判断文件编码的chardet(实例讲解)
Dec 21 Python
python实现日常记账本小程序
Mar 10 Python
Pytorch中accuracy和loss的计算知识点总结
Sep 10 Python
python3 自动打印出最新版本执行的mysql2redis实例
Apr 09 Python
python opencv实现简易画图板
Aug 27 Python
Python eval函数介绍及用法
Nov 09 Python
基于PyTorch中view的用法说明
Mar 03 Python
python垃圾回收机制原理分析
Apr 13 Python
Python中生成随机数据安全性、多功能性、用途和速度方面进行比较
Apr 14 Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
Python list运算操作代码实例解析
Jan 20 #Python
Python模块future用法原理详解
Jan 20 #Python
使用Tensorflow将自己的数据分割成batch训练实例
Jan 20 #Python
Python JSON编解码方式原理详解
Jan 20 #Python
从训练好的tensorflow模型中打印训练变量实例
Jan 20 #Python
You might like
PHP实现图片裁剪、添加水印效果代码
2014/10/01 PHP
php实现扫描二维码根据浏览器类型访问不同下载地址
2014/10/15 PHP
php解析http获取的json字符串变量总是空白null
2015/03/02 PHP
通过修改配置真正解决php文件上传大小限制问题(nginx+php)
2015/09/23 PHP
php慢查询日志和错误日志使用详解
2021/02/27 PHP
JS取文本框中最小值的简单实例
2013/11/29 Javascript
代码触发js事件(click、change)示例应用
2013/12/13 Javascript
js创建对象的区别示例介绍
2014/07/24 Javascript
jQuery的position()方法详解
2015/07/19 Javascript
JavaScript 不支持 indexof 该如何解决
2016/03/30 Javascript
纯JS代码实现气泡效果
2016/05/04 Javascript
JS实现获取当前URL和来源URL的方法
2016/08/24 Javascript
angular route中使用resolve在uglify压缩后问题解决
2016/09/21 Javascript
BootStrap 实现各种样式的进度条效果
2016/12/07 Javascript
jquery判断页面网址是否有效的两种方法
2016/12/11 Javascript
protractor的安装与基本使用教程
2017/07/07 Javascript
Webstorm2016使用技巧(SVN插件使用)
2018/10/29 Javascript
微信小程序官方动态自定义底部tabBar的例子
2019/09/04 Javascript
解决ele ui 表格表头太长问题的实现
2019/11/13 Javascript
[01:14:55]EG vs Spirit Supermajor 败者组 BO3 第三场 6.4
2018/06/05 DOTA
Python与Java间Socket通信实例代码
2017/03/06 Python
Python 专题三 字符串的基础知识
2017/03/19 Python
CentOS7.3编译安装Python3.6.2的方法
2018/01/22 Python
Python使用cx_Oracle模块操作Oracle数据库详解
2018/05/07 Python
解决tensorflow测试模型时NotFoundError错误的问题
2018/07/27 Python
Python调用C++,通过Pybind11制作Python接口
2018/10/16 Python
对Python 多线程统计所有csv文件的行数方法详解
2019/02/12 Python
python pymysql库的常用操作
2020/10/16 Python
CSS3制作3D立方体loading特效
2020/11/09 HTML / CSS
高级护理实习生自荐信
2013/09/28 职场文书
运动会表扬稿大全
2014/01/16 职场文书
结婚老公保证书
2015/02/26 职场文书
2015年毕业生自荐信范文
2015/03/24 职场文书
公司财务部岗位职责
2015/04/14 职场文书
幼儿园万圣节活动总结
2015/05/05 职场文书
Redis如何实现验证码发送 以及限制每日发送次数
2022/04/18 Redis