tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从零学python系列之浅谈pickle模块封装和拆封数据对象的方法
May 23 Python
Python远程桌面协议RDPY安装使用介绍
Apr 15 Python
分析Python中设计模式之Decorator装饰器模式的要点
Mar 02 Python
Python简单生成8位随机密码的方法
May 24 Python
Python数据结构之双向链表的定义与使用方法示例
Jan 16 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 Python
anaconda如何查看并管理python环境
Jul 05 Python
基于Python和PyYAML读取yaml配置文件数据
Jan 13 Python
python matplotlib包图像配色方案分享
Mar 14 Python
Django-imagekit的使用详解
Jul 06 Python
python中如何使用虚拟环境
Oct 14 Python
把Anaconda中的环境导入到Pycharm里面的方法步骤
Oct 30 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
php a simple smtp class
2007/11/26 PHP
解析将多维数组转换为支持curl提交的一维数组格式
2013/07/08 PHP
php实现的美国50个州选择列表实例
2015/04/20 PHP
初识ThinkPHP控制器
2016/04/07 PHP
实例讲解通过​PHP创建数据库
2019/01/20 PHP
PHP PDOStatement::setAttribute讲解
2019/02/01 PHP
PHP+redis实现微博的推模型案例分析
2019/07/10 PHP
PHP pthreads v3下的Volatile简介与使用方法示例
2020/02/21 PHP
Javascript的闭包
2009/12/31 Javascript
Kibo 用于处理键盘事件的Javascript工具库
2011/10/28 Javascript
使用js简单实现了tree树菜单
2013/11/20 Javascript
javascript实现简单的页面右下角提示信息框
2015/07/31 Javascript
Bootstrap学习笔记之css样式设计(2)
2016/06/07 Javascript
JS实现微信弹出搜索框 多条件查询功能
2016/12/13 Javascript
基于jPlayer三分屏的制作方法
2016/12/21 Javascript
Bootstrap组合上、下拉框简单实现代码
2017/03/06 Javascript
JavaScript 完成注册页面表单校验的实例
2017/08/19 Javascript
再谈Angular4 脏值检测(性能优化)
2018/04/23 Javascript
vue addRoutes实现动态权限路由菜单的示例
2018/05/15 Javascript
vue input输入框模糊查询的示例代码
2018/05/22 Javascript
JavaScript变速动画函数封装添加任意多个属性
2019/04/03 Javascript
详解Vue Cli浏览器兼容性实践
2020/06/08 Javascript
jQuery实现朋友圈查看图片
2020/09/11 jQuery
Cython 三分钟入门教程
2009/09/17 Python
在Heroku云平台上部署Python的Django框架的教程
2015/04/20 Python
Python 爬虫的工具列表大全
2016/01/31 Python
Python实现读取邮箱中的邮件功能示例【含文本及附件】
2017/08/05 Python
遗传算法之Python实现代码
2017/10/10 Python
Django中url的反向查询的方法
2018/03/14 Python
python多线程之事件Event的使用详解
2018/04/27 Python
解决pyqt5异常退出无提示信息的问题
2020/04/08 Python
Django nginx配置实现过程详解
2020/09/10 Python
北京麒麟网信息技术有限公司网络游戏测试面试题
2013/09/28 面试题
出国导师推荐信
2015/03/25 职场文书
八年级地理课件资料及考点知识分享
2019/08/30 职场文书
游戏《铁拳》动画化!2022年年内播出
2022/03/21 日漫