tensorflow 20:搭网络,导出模型,运行模型的实例


Posted in Python onMay 26, 2020

概述

以前自己都利用别人搭好的工程,修改过来用,很少把模型搭建、导出模型、加载模型运行走一遍,搞了一遍才知道这个事情也不是那么简单的。

搭建模型和导出模型

参考《TensorFlow固化模型》,导出固化的模型有两种方式.

方式1:导出pb图结构和ckpt文件,然后用 freeze_graph 工具冻结生成一个pb(包含结构和参数)

在我的代码里测试了生成pb图结构和ckpt文件,但是没接着往下走,感觉有点麻烦。我用的是第二种方法。

注意我这里只在最后保存了一次ckpt,实际应该在训练中每隔一段时间就保存一次的。

saver = tf.train.Saver(max_to_keep=5)
 #tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)
 
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

 max_step = 2000
 for i in range(max_step):
 batch = mnist.train.next_batch(50)
 if i % 100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
  x: batch[0], y_: batch[1], keep_prob: 1.0})
 print('step %d, training accuracy %g' % (i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
 
 print('test accuracy %g' % accuracy.eval(feed_dict={
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
 
 # 保存pb和ckpt
 print('save pb file and ckpt file')
 tf.train.write_graph(sess.graph_def, graph_location, "graph.pb",as_text=False)
 checkpoint_path = os.path.join(graph_location, "model.ckpt")
 saver.save(sess, checkpoint_path, global_step=max_step)

方式2:convert_variables_to_constants

我实际使用的就是这种方法。

看名字也知道,就是把变量转化为常量保存,这样就可以愉快的加载使用了。

注意这里需要指明保存的输出节点,我的输出节点为'out/fc2'(我猜测会根据输出节点的依赖推断哪些部分是训练用到的,推理时用不到)。关于输出节点的名字是有规律的,其中out是一个name_scope名字,fc2是op节点的名字。

with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

 max_step = 2000
 for i in range(max_step):
 batch = mnist.train.next_batch(50)
 if i % 100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
  x: batch[0], y_: batch[1], keep_prob: 1.0})
 print('step %d, training accuracy %g' % (i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
 
 print('test accuracy %g' % accuracy.eval(feed_dict={
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

 print('save frozen file')
 pb_path = os.path.join(graph_location, 'frozen_graph.pb')
 print('pb_path:{}'.format(pb_path))

 # 固化模型
 output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
 with tf.gfile.FastGFile(pb_path, mode='wb') as f:
 f.write(output_graph_def.SerializeToString())

上述代码会在训练后把训练好的计算图和参数保存到frozen_graph.pb文件。后续就可以用这个模型来测试图片了。

方式2的完整训练和保存模型代码

主要看main函数就行。另外注意deepnn函数最后节点的名字。

"""A deep MNIST classifier using convolutional layers.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
"""
# Disable linter warnings to maintain consistency with tutorial.
# pylint: disable=invalid-name
# pylint: disable=g-bad-import-order

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile
import os

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants

import tensorflow as tf
FLAGS = None

def deepnn(x):
 """deepnn builds the graph for a deep net for classifying digits.

 Args:
 x: an input tensor with the dimensions (N_examples, 784), where 784 is the
 number of pixels in a standard MNIST image.

 Returns:
 A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values
 equal to the logits of classifying the digit into one of 10 classes (the
 digits 0-9). keep_prob is a scalar placeholder for the probability of
 dropout.
 """
 # Reshape to use within a convolutional neural net.
 # Last dimension is for "features" - there is only one here, since images are
 # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
 with tf.name_scope('reshape'):
 x_image = tf.reshape(x, [-1, 28, 28, 1])

 # First convolutional layer - maps one grayscale image to 32 feature maps.
 with tf.name_scope('conv1'):
 W_conv1 = weight_variable([5, 5, 1, 32])
 b_conv1 = bias_variable([32])
 h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)

 # Pooling layer - downsamples by 2X.
 with tf.name_scope('pool1'):
 h_pool1 = max_pool_2x2(h_conv1)

 # Second convolutional layer -- maps 32 feature maps to 64.
 with tf.name_scope('conv2'):
 W_conv2 = weight_variable([5, 5, 32, 64])
 b_conv2 = bias_variable([64])
 h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

 # Second pooling layer.
 with tf.name_scope('pool2'):
 h_pool2 = max_pool_2x2(h_conv2)

 # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image
 # is down to 7x7x64 feature maps -- maps this to 1024 features.
 with tf.name_scope('fc1'):
 W_fc1 = weight_variable([7 * 7 * 64, 1024])
 b_fc1 = bias_variable([1024])

 h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
 h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

 # Dropout - controls the complexity of the model, prevents co-adaptation of
 # features.
 with tf.name_scope('dropout'):
 keep_prob = tf.placeholder(tf.float32, name='ratio')
 h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

 # Map the 1024 features to 10 classes, one for each digit
 with tf.name_scope('out'):
 W_fc2 = weight_variable([1024, 10])
 b_fc2 = bias_variable([10])

 y_conv = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, name='fc2')
 return y_conv, keep_prob

def conv2d(x, W):
 """conv2d returns a 2d convolution layer with full stride."""
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
 """max_pool_2x2 downsamples a feature map by 2X."""
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
   strides=[1, 2, 2, 1], padding='SAME')

def weight_variable(shape):
 """weight_variable generates a weight variable of a given shape."""
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)

def bias_variable(shape):
 """bias_variable generates a bias variable of a given shape."""
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

def main(_):
 # Import data
 mnist = input_data.read_data_sets(FLAGS.data_dir)

 # Create the model
 with tf.name_scope('input'):
 x = tf.placeholder(tf.float32, [None, 784], name='x')

 # Define loss and optimizer
 y_ = tf.placeholder(tf.int64, [None])

 # Build the graph for the deep net
 y_conv, keep_prob = deepnn(x)

 with tf.name_scope('loss'):
 cross_entropy = tf.losses.sparse_softmax_cross_entropy(
 labels=y_, logits=y_conv)
 cross_entropy = tf.reduce_mean(cross_entropy)

 with tf.name_scope('adam_optimizer'):
 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

 with tf.name_scope('accuracy'):
 correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
 correct_prediction = tf.cast(correct_prediction, tf.float32)
 accuracy = tf.reduce_mean(correct_prediction)

 graph_location = './model'
 print('Saving graph to: %s' % graph_location)
 train_writer = tf.summary.FileWriter(graph_location)
 train_writer.add_graph(tf.get_default_graph())

 saver = tf.train.Saver(max_to_keep=5)
 #tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)
 
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

 max_step = 2000
 for i in range(max_step):
 batch = mnist.train.next_batch(50)
 if i % 100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
  x: batch[0], y_: batch[1], keep_prob: 1.0})
 print('step %d, training accuracy %g' % (i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
 
 print('test accuracy %g' % accuracy.eval(feed_dict={
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
 
 # save pb file and ckpt file
 #print('save pb file and ckpt file')
 #tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
 #checkpoint_path = os.path.join(graph_location, "model.ckpt")
 #saver.save(sess, checkpoint_path, global_step=max_step)

 print('save frozen file')
 pb_path = os.path.join(graph_location, 'frozen_graph.pb')
 print('pb_path:{}'.format(pb_path))

 output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
 with tf.gfile.FastGFile(pb_path, mode='wb') as f:
 f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
 parser = argparse.ArgumentParser()
 parser.add_argument('--data_dir', type=str,
   default='./data',
   help='Directory for storing input data')
 FLAGS, unparsed = parser.parse_known_args()
 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

加载模型进行推理

上一节已经训练并导出了frozen_graph.pb。

这一节把它运行起来。

加载模型

下方的代码用来加载模型。推理时计算图里共两个placeholder需要填充数据,一个是图片(这不废话吗),一个是drouout_ratio,drouout_ratio用一个常量作为输入,后续就只需要输入图片了。

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())

 output = tf.import_graph_def(graph_def,
     input_map={'input/x:0': newInput_X, 'dropout/ratio:0':drouout_ratio},
     return_elements=['out/fc2:0'])

input_map参数并不是必须的。如果不用input_map,可以在run之前用tf.get_default_graph().get_tensor_by_name获取tensor的句柄。但是我觉得这种方法不是很友好,我这里没用这种方法。

注意input_map里的tensor名字是和搭计算图时的name_scope和op名字有关的,而且后面要补一个‘:0'(这点我还没细究)。

同时要注意,newInput_X的形状是[None, 784],第一维是batch大小,推理时和训练要一致。

(我用的是mnist图片,训练时每个bacth的形状是[batchsize, 784],每个图片是28x28)

运行模型

我是一张张图片单独测试的,运行模型之前先把图片变为[1, 784],以符合newInput_X的维数。

with tf.Session( ) as sess:
 file_list = os.listdir(test_image_dir)
 
 # 遍历文件
 for file in file_list:
 full_path = os.path.join(test_image_dir, file)
 print('full_path:{}'.format(full_path))
 
 # 只要黑白的,大小控制在(28,28)
 img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE )
 res_img = cv2.resize(img,(28,28),interpolation=cv2.INTER_CUBIC) 
 # 变成长784的一维数据
 new_img = res_img.reshape((784))
 
 # 增加一个维度,变为 [1, 784]
 image_np_expanded = np.expand_dims(new_img, axis=0)
 image_np_expanded.astype('float32') # 类型也要满足要求
 print('image_np_expanded shape:{}'.format(image_np_expanded.shape))
 
 # 注意注意,我要调用模型了
 result = sess.run(output, feed_dict={newInput_X: image_np_expanded})
 
 # 出来的结果去掉没用的维度
 result = np.squeeze(result)
 print('result:{}'.format(result))
 #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))
 
 # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
 print('result:{}'.format( (np.where(result==np.max(result)))[0][0] ))

注意模型的输出是一个长度为10的一维数组,也就是计算图里全连接的输出。这里没有softmax,只要取最大值的下标即可得到结果。

输出结果:

full_path:./test_images/97_7.jpg
image_np_expanded shape:(1, 784)
result:[-1340.37145996 -283.72436523 1305.03320312 437.6053772 -413.69961548
 -1218.08166504 -1004.83807373 1953.33984375 42.00457001 -504.43829346]
result:7

full_path:./test_images/98_6.jpg
image_np_expanded shape:(1, 784)
result:[ 567.4041748 -550.20904541 623.83496094 -1152.56884766 -217.92695618
 1033.45239258 2496.44750977 -1139.23620605 -5.64091825 -615.28491211]
result:6

full_path:./test_images/99_9.jpg
image_np_expanded shape:(1, 784)
result:[ -532.26409912 -1429.47277832 -368.58096313 505.82876587 358.42163086
 -317.48199463 -1108.6829834 1198.08752441 289.12286377 3083.52539062]
result:9

加载模型进行推理的完整代码

import sys
import os
import cv2
import numpy as np
import tensorflow as tf
test_image_dir = './test_images/'

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 #output = tf.import_graph_def(graph_def)
 output = tf.import_graph_def(graph_def,
     input_map={'input/x:0': newInput_X, 'dropout/ratio:0':drouout_ratio},
     return_elements=['out/fc2:0'])

with tf.Session( ) as sess:
 file_list = os.listdir(test_image_dir)
 
 # 遍历文件
 for file in file_list:
 full_path = os.path.join(test_image_dir, file)
 print('full_path:{}'.format(full_path))
 
 # 只要黑白的,大小控制在(28,28)
 img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE )
 res_img = cv2.resize(img,(28,28),interpolation=cv2.INTER_CUBIC) 
 # 变成长784的一维数据
 new_img = res_img.reshape((784))
 
 # 增加一个维度,变为 [1, 784]
 image_np_expanded = np.expand_dims(new_img, axis=0)
 image_np_expanded.astype('float32') # 类型也要满足要求
 print('image_np_expanded shape:{}'.format(image_np_expanded.shape))
 
 # 注意注意,我要调用模型了
 result = sess.run(output, feed_dict={newInput_X: image_np_expanded})
 
 # 出来的结果去掉没用的维度
 result = np.squeeze(result)
 print('result:{}'.format(result))
 #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))
 
 # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字
 print('result:{}'.format( (np.where(result==np.max(result)))[0][0] ))

以上这篇tensorflow 20:搭网络,导出模型,运行模型的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用md5sum检查目录中相同文件代码分享
Feb 02 Python
Python3控制路由器——使用requests重启极路由.py
May 11 Python
python目录与文件名操作例子
Aug 28 Python
python检测主机的连通性并记录到文件的实例
Jun 21 Python
对Python Class之间函数的调用关系详解
Jan 23 Python
Python数据类型之列表和元组的方法实例详解
Jul 08 Python
使用python实现数组、链表、队列、栈的方法
Dec 20 Python
Tensorflow 实现释放内存
Feb 03 Python
python使用Geany编辑器配置方法
Feb 21 Python
对python中arange()和linspace()的区别说明
May 03 Python
如何使用python写截屏小工具
Sep 29 Python
python requests模块的使用示例
Apr 07 Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
tensorflow实现从.ckpt文件中读取任意变量
May 26 #Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 #Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
May 26 #Python
tensorflow从ckpt和从.pb文件读取变量的值方式
May 26 #Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 #Python
You might like
德生H-501的评价与改造
2021/03/02 无线电
微信access_token的获取开发示例
2015/04/16 PHP
PHP Swoole异步Redis客户端实现方法示例
2019/10/24 PHP
一个判断email合法性的函数[非正则]
2008/12/09 Javascript
JS判断当前日期是否大于某个日期的实现代码
2012/09/02 Javascript
jQuery阻止事件冒泡具体实现
2013/10/11 Javascript
从数组中随机取x条不重复数据的JS代码
2013/12/24 Javascript
一个支付页面DEMO附截图
2014/07/22 Javascript
jQuery中odd选择器的定义和用法
2014/12/23 Javascript
jquery判断复选框是否被选中的方法
2015/10/16 Javascript
jQuery中layer分页器的使用
2017/03/13 Javascript
js实现鼠标拖动功能
2017/03/20 Javascript
angular中实现li或者某个元素点击变色的两种方法
2017/07/27 Javascript
通过fastclick源码分析彻底解决tap“点透”
2017/12/24 Javascript
react 中父组件与子组件双向绑定问题
2019/05/20 Javascript
Vue动态创建注册component的实例代码
2019/06/14 Javascript
vue-cli history模式实现tomcat部署报404的解决方式
2019/09/06 Javascript
js实现坦克移动小游戏
2019/10/28 Javascript
[00:09]DOTA2新版本PA至宝特效动作展示
2014/11/19 DOTA
[01:03:18]DOTA2-DPC中国联赛 正赛 RNG vs Dynasty BO3 第一场 1月29日
2021/03/11 DOTA
python用于url解码和中文解析的小脚本(python url decoder)
2013/08/11 Python
python复制文件代码实现
2013/12/23 Python
按日期打印Python的Tornado框架中的日志的方法
2015/05/02 Python
Python图片裁剪实例代码(如头像裁剪)
2017/06/21 Python
Python 访问限制 private public的详细介绍
2018/10/16 Python
python 实现交换两个列表元素的位置示例
2019/06/26 Python
在Pytorch中计算卷积方法的区别详解(conv2d的区别)
2020/01/03 Python
Python使用多进程运行含有任意个参数的函数
2020/05/02 Python
Django 允许局域网中的机器访问你的主机操作
2020/05/13 Python
详解CSS的border边框属性及其在CSS3中的新特性
2016/05/10 HTML / CSS
法国春天百货官网:Printemps.com
2020/06/29 全球购物
师范应届毕业生自荐信
2013/11/18 职场文书
师范生见习报告
2014/10/31 职场文书
上诉答辩状范文
2015/05/22 职场文书
Python实现Hash算法
2022/03/18 Python
Python尝试实现蒙特卡罗模拟期权定价
2022/04/21 Python