Tensorflow实现在训练好的模型上进行测试


Posted in Python onJanuary 20, 2020

Tensorflow可以使用训练好的模型对新的数据进行测试,有两种方法:第一种方法是调用模型和训练在同一个py文件中,中情况比较简单;第二种是训练过程和调用模型过程分别在两个py文件中。本文将讲解第二种方法。

模型的保存

tensorflow提供可保存训练模型的接口,使用起来也不是很难,直接上代码讲解:

#网络结构
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#损失函数与优化函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess: 
    sess.run(init) 
    saver.save(sess,"save/model.ckpt") 
    train_step.run({x: train_x, y_: train_y})

以上代码就完成了模型的保存,值得注意的是下面这行代码

tf.add_to_collection('network-output', y)

这行代码保存了神经网络的输出,这个在后面使用导入模型过程中起到关键作用。

模型的导入

模型训练并保存后就可以导入来评估模型在测试集上的表现,网上很多文章只用简单的四则运算来做例子,让人看的头大。还是先上代码:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('./model.ckpt.meta')
  saver.restore(sess, './model.ckpt')# .data文件
  pred = tf.get_collection('network-output')[0]

  graph = tf.get_default_graph()
  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

  y = sess.run(pred, feed_dict={x: test_x, y_: test_y})

讲解一下关键的代码,首先是pred = tf.get_collection('pred_network')[0],这行代码获得训练过程中网络输出的“接口”,简单理解就是,通过tf.get_collection() 这个方法获取了整个网络结构。获得网络结构后我们就需要喂它对应的数据y = sess.run(pred, feed_dict={x: test_x, y_: test_y}) 在训练过程中我们的输入是

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

因此导入模型后所需的输入也要与之对应可使用以下代码获得:

x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

使用模型的最后一步就是输入测试集,然后按照训练好的网络进行评估

sess.run(pred, feed_dict={x: test_x, y_: test_y})

理解下这行代码,sess.run() 的函数原型为

run(fetches, feed_dict=None, options=None, run_metadata=None)

Tensorflow对 feed_dict 执行fetches操作,因此在导入模型后的运算就是,按照训练的网络计算测试输入的数据。

以上这篇Tensorflow实现在训练好的模型上进行测试就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中string模块各属性以及函数的用法介绍
May 30 Python
在Python程序和Flask框架中使用SQLAlchemy的教程
Jun 06 Python
Python使用re模块正则提取字符串中括号内的内容示例
Jun 01 Python
python3实现域名查询和whois查询功能
Jun 21 Python
在Python文件中指定Python解释器的方法
Feb 18 Python
python实践项目之监控当前联网状态详情
May 23 Python
python3字符串操作总结
Jul 24 Python
Python流程控制 while循环实现解析
Sep 02 Python
python多进程并行代码实例
Sep 30 Python
Django 创建后台,配置sqlite3教程
Nov 18 Python
python小白学习包管理器pip安装
Jun 09 Python
python更新数据库中某个字段的数据(方法详解)
Nov 18 Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
Python list运算操作代码实例解析
Jan 20 #Python
Python模块future用法原理详解
Jan 20 #Python
使用Tensorflow将自己的数据分割成batch训练实例
Jan 20 #Python
Python JSON编解码方式原理详解
Jan 20 #Python
从训练好的tensorflow模型中打印训练变量实例
Jan 20 #Python
You might like
全国FM电台频率大全 - 5 内蒙古自治区
2020/03/11 无线电
怎样在PHP中通过ADO调用Asscess数据库和COM程序
2006/10/09 PHP
PHP调用MySQL的存储过程的实现代码
2008/08/12 PHP
php删除字符串末尾子字符,删除开始字符,删除两端字符(实现代码)
2013/06/27 PHP
php文字水印和php图片水印实现代码(二种加水印方法)
2013/12/25 PHP
php合并数组中相同元素的方法
2014/11/13 PHP
php快速查找数据库中恶意代码的方法
2015/04/01 PHP
PHP编程中尝试程序并发的几种方式总结
2016/03/21 PHP
Joomla语言翻译类Jtext用法分析
2016/05/05 PHP
PHP如何实现阿里云短信sdk灵活应用在项目中的方法
2019/06/14 PHP
Laravel框架源码解析之入口文件原理分析
2020/05/14 PHP
基于node.js的快速开发透明代理
2010/12/25 Javascript
javascript实现促销倒计时+fixed固定在底部
2013/09/18 Javascript
模拟一个类似百度google的模糊搜索下拉列表
2014/04/15 Javascript
jQuery.extend()、jQuery.fn.extend()扩展方法示例详解
2014/05/08 Javascript
jQuery+ajax实现鼠标单击修改内容的思路
2014/06/29 Javascript
使用javascript实现判断当前浏览器
2015/04/14 Javascript
javaScript中with函数用法实例分析
2015/06/08 Javascript
跟我学习javascript创建对象(类)的8种方法
2015/11/20 Javascript
js仿QQ邮箱收件人选择与搜索功能
2017/02/10 Javascript
全面解析Node.js 8 重要功能和修复
2017/06/02 Javascript
微信小程序图片宽100%显示并且不变形
2017/06/21 Javascript
详解RequireJs官方使用教程
2017/10/31 Javascript
vue2.0 + element UI 中 el-table 数据导出Excel的方法
2018/03/02 Javascript
解决Vue的文本编辑器 vue-quill-editor 小图标样式排布错乱问题
2020/08/03 Javascript
python中cPickle类使用方法详解
2018/08/27 Python
git查看、创建、删除、本地、远程分支方法详解
2020/02/18 Python
python实现文件+参数发送request的实例代码
2021/01/05 Python
墨西哥运动服饰和鞋网上商店:Netshoes墨西哥
2016/07/28 全球购物
Tarte Cosmetics官网:美国最受欢迎的化妆品公司之一
2017/08/24 全球购物
经典c++面试题二
2015/08/14 面试题
优秀毕业生推荐信
2013/11/02 职场文书
优秀实习自我鉴定
2013/12/04 职场文书
少年雷锋观后感
2015/06/10 职场文书
Go语言-为什么返回值为接口类型,却返回结构体
2021/04/24 Golang
人民币符号
2022/02/17 杂记