tensorflow 20: building a network, exporting the model, and running it — a worked example


Posted in Python on May 26, 2020

Overview

Until now I had always taken projects that other people had already set up and adapted them to my needs; I had rarely walked through the whole cycle of building a model, exporting it, and loading it to run inference myself. Having now done it end to end, I can say it is not as simple as it looks.

Building and exporting the model

Following the article 《TensorFlow固化模型》 (on freezing TensorFlow models), there are two ways to export a frozen model.

Method 1: export the graph structure as a pb file plus a ckpt checkpoint, then use the freeze_graph tool to freeze them into a single pb that contains both the structure and the weights.

In my code I tested generating the pb graph structure and the ckpt file, but didn't take it any further — it felt a bit cumbersome. I used the second method instead.

Note that here I only save a ckpt once, at the very end; in practice you should save one every so often during training.
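A minimal sketch of what that periodic saving could look like inside the training loop below (the 500-step interval is just an illustration I picked; with max_to_keep=5 the Saver keeps only the five most recent checkpoints):

  # Hypothetical: checkpoint every 500 steps instead of only once at the end.
  if i % 500 == 0:
    saver.save(sess, checkpoint_path, global_step=i)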

saver = tf.train.Saver(max_to_keep=5)
#tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  max_step = 2000
  for i in range(max_step):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
      train_accuracy = accuracy.eval(feed_dict={
          x: batch[0], y_: batch[1], keep_prob: 1.0})
      print('step %d, training accuracy %g' % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

  print('test accuracy %g' % accuracy.eval(feed_dict={
      x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

  # save the pb graph structure and the ckpt checkpoint
  print('save pb file and ckpt file')
  tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
  checkpoint_path = os.path.join(graph_location, "model.ckpt")
  saver.save(sess, checkpoint_path, global_step=max_step)
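For completeness: had I carried Method 1 through, the freeze step would use TensorFlow's freeze_graph tool on the two files saved above, roughly like this (a sketch I did not actually run; the paths come from the code above and the node name from Method 2 below):

python -m tensorflow.python.tools.freeze_graph \
    --input_graph=./model/graph.pb \
    --input_binary=true \
    --input_checkpoint=./model/model.ckpt-2000 \
    --output_node_names=out/fc2 \
    --output_graph=./model/frozen_graph.pb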

Method 2: convert_variables_to_constants

This is the method I actually used.

As the name suggests, it converts the graph's variables to constants before saving, so the result can simply be loaded and used — no separate checkpoint needed.

Note that you have to name the output node(s) to keep. Mine is 'out/fc2' (my guess is that the converter follows the output nodes' dependencies to work out which parts of the graph are only needed for training and can be dropped for inference). The node name follows a pattern: out is a name_scope, and fc2 is the name of the op inside it.
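If you are not sure what a node ended up being called, a quick way to check (a sketch; run it after the graph has been built) is to print every op name in the default graph:

# List all op names so you can find the output node, e.g. 'out/fc2'.
for op in tf.get_default_graph().get_operations():
  print(op.name)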

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  max_step = 2000
  for i in range(max_step):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
      train_accuracy = accuracy.eval(feed_dict={
          x: batch[0], y_: batch[1], keep_prob: 1.0})
      print('step %d, training accuracy %g' % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

  print('test accuracy %g' % accuracy.eval(feed_dict={
      x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

  print('save frozen file')
  pb_path = os.path.join(graph_location, 'frozen_graph.pb')
  print('pb_path:{}'.format(pb_path))

  # freeze the model: bake the trained variables into the graph as constants
  output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
  with tf.gfile.FastGFile(pb_path, mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

The code above trains the model and then saves the trained graph plus its weights to frozen_graph.pb. That single file is all you need to classify test images later.
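As a quick sanity check (my addition, not part of the original flow), you can parse the exported file back in and confirm the output node survived the freeze:

# Parse frozen_graph.pb and check that 'out/fc2' is among its nodes.
gd = tf.GraphDef()
with tf.gfile.GFile('./model/frozen_graph.pb', 'rb') as f:
  gd.ParseFromString(f.read())
print('{} nodes, has out/fc2: {}'.format(
    len(gd.node), any(n.name == 'out/fc2' for n in gd.node)))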

Complete training and model-export code for Method 2

Mainly look at the main function. Also note the name given to the final node in deepnn.

"""A deep MNIST classifier using convolutional layers.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/pros
"""
# Disable linter warnings to maintain consistency with tutorial.
# pylint: disable=invalid-name
# pylint: disable=g-bad-import-order

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile
import os

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants

import tensorflow as tf
FLAGS = None

def deepnn(x):
 """deepnn builds the graph for a deep net for classifying digits.

 Args:
 x: an input tensor with the dimensions (N_examples, 784), where 784 is the
 number of pixels in a standard MNIST image.

 Returns:
 A tuple (y, keep_prob). y is a tensor of shape (N_examples, 10), with values
 equal to the logits of classifying the digit into one of 10 classes (the
 digits 0-9). keep_prob is a scalar placeholder for the probability of
 dropout.
 """
 # Reshape to use within a convolutional neural net.
 # Last dimension is for "features" - there is only one here, since images are
 # grayscale -- it would be 3 for an RGB image, 4 for RGBA, etc.
 with tf.name_scope('reshape'):
 x_image = tf.reshape(x, [-1, 28, 28, 1])

 # First convolutional layer - maps one grayscale image to 32 feature maps.
 with tf.name_scope('conv1'):
 W_conv1 = weight_variable([5, 5, 1, 32])
 b_conv1 = bias_variable([32])
 h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)

 # Pooling layer - downsamples by 2X.
 with tf.name_scope('pool1'):
 h_pool1 = max_pool_2x2(h_conv1)

 # Second convolutional layer -- maps 32 feature maps to 64.
 with tf.name_scope('conv2'):
 W_conv2 = weight_variable([5, 5, 32, 64])
 b_conv2 = bias_variable([64])
 h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

 # Second pooling layer.
 with tf.name_scope('pool2'):
 h_pool2 = max_pool_2x2(h_conv2)

 # Fully connected layer 1 -- after 2 round of downsampling, our 28x28 image
 # is down to 7x7x64 feature maps -- maps this to 1024 features.
 with tf.name_scope('fc1'):
 W_fc1 = weight_variable([7 * 7 * 64, 1024])
 b_fc1 = bias_variable([1024])

 h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
 h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

 # Dropout - controls the complexity of the model, prevents co-adaptation of
 # features.
 with tf.name_scope('dropout'):
 keep_prob = tf.placeholder(tf.float32, name='ratio')
 h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

 # Map the 1024 features to 10 classes, one for each digit
 with tf.name_scope('out'):
 W_fc2 = weight_variable([1024, 10])
 b_fc2 = bias_variable([10])

 y_conv = tf.add(tf.matmul(h_fc1_drop, W_fc2), b_fc2, name='fc2')
 return y_conv, keep_prob

def conv2d(x, W):
 """conv2d returns a 2d convolution layer with full stride."""
 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
 """max_pool_2x2 downsamples a feature map by 2X."""
 return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
   strides=[1, 2, 2, 1], padding='SAME')

def weight_variable(shape):
 """weight_variable generates a weight variable of a given shape."""
 initial = tf.truncated_normal(shape, stddev=0.1)
 return tf.Variable(initial)

def bias_variable(shape):
 """bias_variable generates a bias variable of a given shape."""
 initial = tf.constant(0.1, shape=shape)
 return tf.Variable(initial)

def main(_):
 # Import data
 mnist = input_data.read_data_sets(FLAGS.data_dir)

 # Create the model
 with tf.name_scope('input'):
 x = tf.placeholder(tf.float32, [None, 784], name='x')

 # Define loss and optimizer
 y_ = tf.placeholder(tf.int64, [None])

 # Build the graph for the deep net
 y_conv, keep_prob = deepnn(x)

 with tf.name_scope('loss'):
 cross_entropy = tf.losses.sparse_softmax_cross_entropy(
 labels=y_, logits=y_conv)
 cross_entropy = tf.reduce_mean(cross_entropy)

 with tf.name_scope('adam_optimizer'):
 train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

 with tf.name_scope('accuracy'):
 correct_prediction = tf.equal(tf.argmax(y_conv, 1), y_)
 correct_prediction = tf.cast(correct_prediction, tf.float32)
 accuracy = tf.reduce_mean(correct_prediction)

 graph_location = './model'
 print('Saving graph to: %s' % graph_location)
 train_writer = tf.summary.FileWriter(graph_location)
 train_writer.add_graph(tf.get_default_graph())

 saver = tf.train.Saver(max_to_keep=5)
 #tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True)
 
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

 max_step = 2000
 for i in range(max_step):
 batch = mnist.train.next_batch(50)
 if i % 100 == 0:
 train_accuracy = accuracy.eval(feed_dict={
  x: batch[0], y_: batch[1], keep_prob: 1.0})
 print('step %d, training accuracy %g' % (i, train_accuracy))
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
 
 print('test accuracy %g' % accuracy.eval(feed_dict={
 x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
 
 # save pb file and ckpt file
 #print('save pb file and ckpt file')
 #tf.train.write_graph(sess.graph_def, graph_location, "graph.pb", as_text=False)
 #checkpoint_path = os.path.join(graph_location, "model.ckpt")
 #saver.save(sess, checkpoint_path, global_step=max_step)

 print('save frozen file')
 pb_path = os.path.join(graph_location, 'frozen_graph.pb')
 print('pb_path:{}'.format(pb_path))

 output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['out/fc2'])
 with tf.gfile.FastGFile(pb_path, mode='wb') as f:
 f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
 parser = argparse.ArgumentParser()
 parser.add_argument('--data_dir', type=str,
   default='./data',
   help='Directory for storing input data')
 FLAGS, unparsed = parser.parse_known_args()
 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
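To run it, assuming the script is saved as mnist_train.py (the file name is my own choice):

python mnist_train.py --data_dir ./data

On the first run the MNIST data is downloaded into ./data, and after training the frozen model lands in ./model/frozen_graph.pb.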

Loading the model and running inference

The previous section trained the model and exported frozen_graph.pb.

This section gets it running.

Loading the model

The code below loads the model. At inference time there are two placeholders in the graph that need to be fed: one for the image (obviously), and one for the dropout ratio. I map the dropout ratio to a constant, so afterwards only the image needs to be fed.

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())

  output = tf.import_graph_def(graph_def,
                               input_map={'input/x:0': newInput_X, 'dropout/ratio:0': drouout_ratio},
                               return_elements=['out/fc2:0'])

The input_map argument is not strictly necessary. Without it, you can grab the tensor handles with tf.get_default_graph().get_tensor_by_name before calling run. I find that approach less friendly, so I didn't use it here.
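For reference, a minimal sketch of that alternative (same node names as above; note that tf.import_graph_def prefixes imported names with 'import/' unless you pass name=''):

# Import without input_map, then look the tensors up by name.
tf.import_graph_def(graph_def, name='')
g = tf.get_default_graph()
x_in = g.get_tensor_by_name('input/x:0')
ratio = g.get_tensor_by_name('dropout/ratio:0')
logits = g.get_tensor_by_name('out/fc2:0')
# then: sess.run(logits, feed_dict={x_in: images, ratio: 1.0})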

Note that the tensor names in input_map are determined by the name_scopes and op names used when the graph was built, with a ':0' appended (the suffix indexes the op's output tensors, so ':0' is its first output).

Also note that newInput_X has shape [None, 784]: the first dimension is the batch size, and the data layout at inference time has to match training.

(I'm using MNIST images; during training each batch has shape [batchsize, 784], each image being 28x28.)
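Since the first dimension is the batch, nothing stops you from classifying several images per run. A sketch (img_a and img_b are hypothetical 28x28 grayscale arrays; sess, output and newInput_X are set up as in the code below):

# Stack two flattened images into a [2, 784] batch and run them together.
batch = np.stack([img_a.reshape(784), img_b.reshape(784)]).astype('float32')
results = sess.run(output, feed_dict={newInput_X: batch})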

Running the model

I test the images one at a time, so before running the model I reshape each image to [1, 784] to match newInput_X's dimensions.

with tf.Session() as sess:
  file_list = os.listdir(test_image_dir)

  # iterate over the test images
  for file in file_list:
    full_path = os.path.join(test_image_dir, file)
    print('full_path:{}'.format(full_path))

    # grayscale only, resized to (28, 28)
    img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
    res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
    # flatten into a 1-D array of length 784
    new_img = res_img.reshape((784))

    # add a batch dimension: [1, 784]
    image_np_expanded = np.expand_dims(new_img, axis=0)
    # the dtype has to match the placeholder too (astype returns a new array)
    image_np_expanded = image_np_expanded.astype('float32')
    print('image_np_expanded shape:{}'.format(image_np_expanded.shape))

    # run the model
    result = sess.run(output, feed_dict={newInput_X: image_np_expanded})

    # squeeze out the useless dimensions
    result = np.squeeze(result)
    print('result:{}'.format(result))
    #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))

    # the output is a length-10 vector (digits 0-9); the index of the
    # maximum value is the predicted digit
    print('result:{}'.format(np.argmax(result)))

Note that the model's output is a 1-D array of length 10 — the raw output of the final fully connected layer. There is no softmax in this graph, so taking the index of the maximum value directly gives the predicted digit.
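If you would rather have probabilities than raw logits, a softmax can be applied outside the graph (my addition, not part of the original pipeline; subtracting the max keeps the exponentials numerically stable):

# Softmax over the logits, computed in NumPy after sess.run.
probs = np.exp(result - np.max(result))
probs /= probs.sum()
print('digit {} with probability {:.3f}'.format(np.argmax(probs), probs.max()))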

Output:

full_path:./test_images/97_7.jpg
image_np_expanded shape:(1, 784)
result:[-1340.37145996 -283.72436523 1305.03320312 437.6053772 -413.69961548
 -1218.08166504 -1004.83807373 1953.33984375 42.00457001 -504.43829346]
result:7

full_path:./test_images/98_6.jpg
image_np_expanded shape:(1, 784)
result:[ 567.4041748 -550.20904541 623.83496094 -1152.56884766 -217.92695618
 1033.45239258 2496.44750977 -1139.23620605 -5.64091825 -615.28491211]
result:6

full_path:./test_images/99_9.jpg
image_np_expanded shape:(1, 784)
result:[ -532.26409912 -1429.47277832 -368.58096313 505.82876587 358.42163086
 -317.48199463 -1108.6829834 1198.08752441 289.12286377 3083.52539062]
result:9

Complete code for loading the model and running inference

import sys
import os
import cv2
import numpy as np
import tensorflow as tf

test_image_dir = './test_images/'

graph_location = './model'
pb_path = os.path.join(graph_location, 'frozen_graph.pb')
print('pb_path:{}'.format(pb_path))

newInput_X = tf.placeholder(tf.float32, [None, 784], name="X")
drouout_ratio = tf.constant(1., name="drouout")
with open(pb_path, 'rb') as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())
  #output = tf.import_graph_def(graph_def)
  output = tf.import_graph_def(graph_def,
                               input_map={'input/x:0': newInput_X, 'dropout/ratio:0': drouout_ratio},
                               return_elements=['out/fc2:0'])

with tf.Session() as sess:
  file_list = os.listdir(test_image_dir)

  # iterate over the test images
  for file in file_list:
    full_path = os.path.join(test_image_dir, file)
    print('full_path:{}'.format(full_path))

    # grayscale only, resized to (28, 28)
    img = cv2.imread(full_path, cv2.IMREAD_GRAYSCALE)
    res_img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_CUBIC)
    # flatten into a 1-D array of length 784
    new_img = res_img.reshape((784))

    # add a batch dimension: [1, 784]
    image_np_expanded = np.expand_dims(new_img, axis=0)
    # the dtype has to match the placeholder too (astype returns a new array)
    image_np_expanded = image_np_expanded.astype('float32')
    print('image_np_expanded shape:{}'.format(image_np_expanded.shape))

    # run the model
    result = sess.run(output, feed_dict={newInput_X: image_np_expanded})

    # squeeze out the useless dimensions
    result = np.squeeze(result)
    print('result:{}'.format(result))
    #print('result:{}'.format(sess.run(output, feed_dict={newInput_X: image_np_expanded})))

    # the output is a length-10 vector (digits 0-9); the index of the
    # maximum value is the predicted digit
    print('result:{}'.format(np.argmax(result)))
