Tensorflow实现在训练好的模型上进行测试


Posted in Python onJanuary 20, 2020

Tensorflow可以使用训练好的模型对新的数据进行测试,有两种方法:第一种方法是调用模型和训练在同一个py文件中,中情况比较简单;第二种是训练过程和调用模型过程分别在两个py文件中。本文将讲解第二种方法。

模型的保存

tensorflow提供可保存训练模型的接口,使用起来也不是很难,直接上代码讲解:

#网络结构
w1 = tf.Variable(tf.truncated_normal([in_units, h1_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h1_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#损失函数与优化函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)

saver = tf.train.Saver()
with tf.Session() as sess: 
    sess.run(init) 
    saver.save(sess,"save/model.ckpt") 
    train_step.run({x: train_x, y_: train_y})

以上代码就完成了模型的保存,值得注意的是下面这行代码

tf.add_to_collection('network-output', y)

这行代码保存了神经网络的输出,这个在后面使用导入模型过程中起到关键作用。

模型的导入

模型训练并保存后就可以导入来评估模型在测试集上的表现,网上很多文章只用简单的四则运算来做例子,让人看的头大。还是先上代码:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('./model.ckpt.meta')
  saver.restore(sess, './model.ckpt')# .data文件
  pred = tf.get_collection('network-output')[0]

  graph = tf.get_default_graph()
  x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

  y = sess.run(pred, feed_dict={x: test_x, y_: test_y})

讲解一下关键的代码,首先是pred = tf.get_collection('pred_network')[0],这行代码获得训练过程中网络输出的“接口”,简单理解就是,通过tf.get_collection() 这个方法获取了整个网络结构。获得网络结构后我们就需要喂它对应的数据y = sess.run(pred, feed_dict={x: test_x, y_: test_y}) 在训练过程中我们的输入是

x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

因此导入模型后所需的输入也要与之对应可使用以下代码获得:

x = graph.get_operation_by_name('x').outputs[0]
  y_ = graph.get_operation_by_name('y_').outputs[0]

使用模型的最后一步就是输入测试集,然后按照训练好的网络进行评估

sess.run(pred, feed_dict={x: test_x, y_: test_y})

理解下这行代码,sess.run() 的函数原型为

run(fetches, feed_dict=None, options=None, run_metadata=None)

Tensorflow对 feed_dict 执行fetches操作,因此在导入模型后的运算就是,按照训练的网络计算测试输入的数据。

以上这篇Tensorflow实现在训练好的模型上进行测试就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django URL传递参数的方法总结
Aug 28 Python
python中redis的安装和使用
Dec 04 Python
python自定义异常实例详解
Jul 11 Python
Django实现分页功能
Jul 02 Python
基于Python实现定时自动给微信好友发送天气预报
Oct 25 Python
Django框架模板语言实例小结【变量,标签,过滤器,继承,html转义】
May 23 Python
自定义django admin model表单提交的例子
Aug 23 Python
TensorFlow tensor的拼接实例
Jan 19 Python
python有序查找算法 二分法实例解析
Feb 18 Python
Python-jenkins模块之folder相关操作介绍
May 12 Python
Python3 pywin32模块安装的详细步骤
May 26 Python
Python爬虫之爬取淘女郎照片示例详解
Jul 28 Python
Python线程条件变量Condition原理解析
Jan 20 #Python
tensorflow tf.train.batch之数据批量读取方式
Jan 20 #Python
Python list运算操作代码实例解析
Jan 20 #Python
Python模块future用法原理详解
Jan 20 #Python
使用Tensorflow将自己的数据分割成batch训练实例
Jan 20 #Python
Python JSON编解码方式原理详解
Jan 20 #Python
从训练好的tensorflow模型中打印训练变量实例
Jan 20 #Python
You might like
smarty的保留变量问题
2008/10/23 PHP
PHP游戏编程25个脚本代码
2011/02/08 PHP
CI框架中zip类应用示例
2014/06/17 PHP
详解php设置session(过期、失效、有效期)
2015/11/12 PHP
PHP面向对象程序设计之构造方法和析构方法详解
2019/06/13 PHP
基于JQuery的密码强度验证代码
2010/03/01 Javascript
JavaScript限定复选框的选择个数示例代码
2013/08/25 Javascript
用js+iframe形成页面的一种遮罩效果的具体实现
2013/12/31 Javascript
JS计算网页停留时间代码
2014/04/28 Javascript
js中键盘事件实例简析
2015/01/10 Javascript
JS选取DOM元素的简单方法
2016/07/08 Javascript
理解 javascript 中的函数表达式与函数声明
2017/07/07 Javascript
VUE DEMO之模拟登录个人中心页面之间数据传值实例
2019/10/31 Javascript
js实现整体缩放页面适配移动端
2020/03/31 Javascript
基于canvasJS在PHP中制作动态图表
2020/05/30 Javascript
js实现表格数据搜索
2020/08/09 Javascript
python实现获取客户机上指定文件并传输到服务器的方法
2015/03/16 Python
Python中str.format()详解
2017/03/12 Python
python编程使用selenium模拟登陆淘宝实例代码
2018/01/25 Python
Python实现扣除个人税后的工资计算器示例
2018/03/26 Python
详解python while 函数及while和for的区别
2018/09/07 Python
Python递归求出列表(包括列表中的子列表)的最大值实例
2020/02/27 Python
使paramiko库执行命令时在给定的时间强制退出功能的实现
2021/03/03 Python
pytorch Dataset,DataLoader产生自定义的训练数据案例
2021/03/03 Python
HTML5添加禁止缩放功能
2017/11/03 HTML / CSS
Intersport西班牙:在线体育商店
2019/11/06 全球购物
元旦红领巾广播稿
2014/02/19 职场文书
“九一八事变纪念日”国旗下讲话稿
2014/09/14 职场文书
公司内部升职自荐信
2015/03/27 职场文书
公司奖励通知
2015/04/21 职场文书
CSS3点击按钮圆形进度打钩效果的实现代码
2021/03/30 HTML / CSS
redis通过6379端口无法连接服务器(redis-server.exe闪退)
2021/05/08 Redis
Python-OpenCV教程之图像的位运算详解
2021/06/21 Python
python中的class_static的@classmethod的巧妙用法
2021/06/22 Python
python实现手机推送 代码也就10行左右
2022/04/12 Python
微软官方消息,在 2023 年 4 月 11 日之后微软将不再为 Office 2013 和 Skype for Business 2015 提供安全更新
2022/04/21 数码科技