Tensorflow实现在训练好的模型上进行测试


Posted in Python onJanuary 20, 2020

Tensorflow可以使用训练好的模型对新的数据进行测试,有两种方法:第一种方法是调用模型和训练在同一个py文件中,中情况比较简单;第二种是训练过程和调用模型过程分别在两个py文件中。本文将讲解第二种方法。

模型的保存

tensorflow提供可保存训练模型的接口,使用起来也不是很难,直接上代码讲解:

#网络结构
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#损失函数与优化函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess: 
    sess.run(init) 
    saver.save(sess,"save/model.ckpt") 
    train_step.run({x: train_x, y_: train_y})

以上代码就完成了模型的保存,值得注意的是下面这行代码

tf.add_to_collection('network-output', y)

这行代码保存了神经网络的输出,这个在后面使用导入模型过程中起到关键作用。

模型的导入

模型训练并保存后就可以导入来评估模型在测试集上的表现,网上很多文章只用简单的四则运算来做例子,让人看的头大。还是先上代码:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('./model.ckpt.meta')
  saver.restore(sess, './model.ckpt')# .data文件
  pred = tf.get_collection('network-output')[0]

  graph = tf.get_default_graph()
  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

  y = sess.run(pred, feed_dict={x: test_x, y_: test_y})

讲解一下关键的代码,首先是pred = tf.get_collection('pred_network')[0],这行代码获得训练过程中网络输出的“接口”,简单理解就是,通过tf.get_collection() 这个方法获取了整个网络结构。获得网络结构后我们就需要喂它对应的数据y = sess.run(pred, feed_dict={x: test_x, y_: test_y}) 在训练过程中我们的输入是

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

因此导入模型后所需的输入也要与之对应可使用以下代码获得:

x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

使用模型的最后一步就是输入测试集,然后按照训练好的网络进行评估

sess.run(pred, feed_dict={x: test_x, y_: test_y})

理解下这行代码,sess.run() 的函数原型为

run(fetches, feed_dict=None, options=None, run_metadata=None)

Tensorflow对 feed_dict 执行fetches操作,因此在导入模型后的运算就是,按照训练的网络计算测试输入的数据。

以上这篇Tensorflow实现在训练好的模型上进行测试就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常用的日期时间处理方法示例
Feb 08 Python
python复制文件的方法实例详解
May 22 Python
Python中常见的数据类型小结
Aug 29 Python
详谈python http长连接客户端
Jun 12 Python
python初学之用户登录的实现过程(实例讲解)
Dec 23 Python
python 整数越界问题详解
Jun 27 Python
python 3.7.4 安装 opencv的教程
Oct 10 Python
Python基础类继承重写实现原理解析
Apr 03 Python
django使用channels实现通信的示例
Oct 19 Python
几款好用的python工具库(小结)
Oct 20 Python
教你用python实现一个无界面的小型图书管理系统
May 21 Python
Python 内置函数速查表一览
Jun 02 Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
Python list运算操作代码实例解析
Jan 20 #Python
Python模块future用法原理详解
Jan 20 #Python
使用Tensorflow将自己的数据分割成batch训练实例
Jan 20 #Python
Python JSON编解码方式原理详解
Jan 20 #Python
从训练好的tensorflow模型中打印训练变量实例
Jan 20 #Python
You might like
第五节 克隆 [5]
2006/10/09 PHP
PHP与MySQL交互使用详解
2006/10/09 PHP
PHP安装攻略:常见问题解答(三)
2006/10/09 PHP
PHP安装BCMath扩展的方法
2019/02/13 PHP
半角全角相互转换的js函数
2009/10/16 Javascript
jquery.AutoComplete.js中文修正版(支持firefox)
2010/04/09 Javascript
jquery.validate分组验证代码
2011/03/17 Javascript
editable.js 基于jquery的表格的编辑插件
2011/10/24 Javascript
jquery计算鼠标和指定元素之间距离的方法
2015/06/26 Javascript
d3.js实现简单的网络拓扑图实例代码
2016/11/06 Javascript
JavaScript轻松创建级联函数的方法示例
2017/02/10 Javascript
Sublime Text新建.vue模板并高亮(图文教程)
2017/10/26 Javascript
jQuery实现碰到边缘反弹的动画效果
2018/02/24 jQuery
JS计算斐波拉切代码实例
2019/09/12 Javascript
Vue实现简单的跑马灯
2020/05/25 Javascript
Vue和React有哪些区别
2020/09/12 Javascript
解决VUE 在IE下出现ReferenceError: Promise未定义的问题
2020/11/07 Javascript
[04:16]DOTA2英雄梦之声_第09期_斧王
2014/06/21 DOTA
[02:22]《新闻直播间》2017年08月14日
2017/08/15 DOTA
ssh批量登录并执行命令的python实现代码
2012/05/25 Python
python3爬虫怎样构建请求header
2018/12/23 Python
基于django和dropzone.js实现上传文件
2020/11/24 Python
90后毕业生的求职信范文
2013/09/21 职场文书
优秀员工自荐书范文
2013/12/08 职场文书
节能环保口号
2014/06/12 职场文书
个人股份合作协议书
2014/10/24 职场文书
个人股份转让协议书范本
2015/01/28 职场文书
大学军训通讯稿
2015/07/18 职场文书
就业指导讲座心得体会
2016/01/15 职场文书
《清澈的湖水》教学反思
2016/02/17 职场文书
幼儿园2016年感恩节活动总结
2016/04/01 职场文书
2016年大学光棍节活动总结
2016/04/05 职场文书
MySQL 如何设计统计数据表
2021/06/15 MySQL
浅析Python实现DFA算法
2021/06/26 Python
python not运算符的实例用法
2021/06/30 Python
JavaScript 数组去重详解
2021/09/15 Javascript