tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 不同对象比较大小示例探讨
Aug 21 Python
浅谈编码,解码,乱码的问题
Dec 30 Python
python八大排序算法速度实例对比
Dec 06 Python
使用pycharm设置控制台不换行的操作方法
Jan 19 Python
python的pytest框架之命令行参数详解(上)
Jun 27 Python
如何安装并使用conda指令管理python环境
Jul 10 Python
Django项目使用CircleCI的方法示例
Jul 14 Python
python openpyxl使用方法详解
Jul 18 Python
在django中使用apscheduler 执行计划任务的实现方法
Feb 11 Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 Python
彻底解决Python包下载慢问题
Nov 15 Python
Python 将代码转换为可执行文件脱离python环境运行(步骤详解)
Jan 25 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
五个PHP程序员工具
2008/05/26 PHP
memcached 和 mysql 主从环境下php开发代码详解
2010/05/16 PHP
php daddslashes()和 saddslashes()有哪些区别分析
2012/10/26 PHP
php版微信公众平台接口开发之智能回复开发教程
2016/09/22 PHP
CI框架中类的自动加载问题分析
2016/11/21 PHP
PHP 文件锁与进程锁的使用示例
2017/08/07 PHP
PHP中遍历数组的三种常用方法实例分析
2019/06/24 PHP
Thinkphp 框架配置操作之配置加载与读取配置实例分析
2020/05/15 PHP
常用js脚本
2006/12/03 Javascript
javascript中的undefined 与 null 的区别  补充篇
2010/03/17 Javascript
jquery与prototype框架的详细对比
2013/11/21 Javascript
node.js解决获取图片真实文件类型的问题
2014/12/20 Javascript
javascript中的五种基本数据类型
2015/08/26 Javascript
AngularJS中实现显示或隐藏动画效果的方式总结
2015/12/31 Javascript
bootstrap table之通用方法( 时间控件,导出,动态下拉框, 表单验证 ,选中与获取信息)代码分享
2017/01/24 Javascript
AngularJS 实现购物车全选反选功能
2017/10/24 Javascript
three.js实现3D视野缩放效果
2017/11/16 Javascript
vue实现树形菜单效果
2018/03/19 Javascript
vue.js前后端数据交互之提交数据操作详解
2018/04/24 Javascript
jQuery实现的页面详情展开收起功能示例
2018/06/11 jQuery
Electron实现应用打包、自动升级过程解析
2020/07/07 Javascript
python学习 流程控制语句详解
2016/06/01 Python
Python遍历目录并批量更换文件名和目录名的方法
2016/09/19 Python
Python使用sftp实现上传和下载功能(实例代码)
2017/03/14 Python
Python实现的简单计算器功能详解
2018/08/25 Python
python+opencv 读取文件夹下的所有图像并批量保存ROI的方法
2019/01/10 Python
Python列表如何更新值
2020/05/27 Python
使paramiko库执行命令时在给定的时间强制退出功能的实现
2021/03/03 Python
详解CSS的border边框属性及其在CSS3中的新特性
2016/05/10 HTML / CSS
德国领先的大尺码和超大尺码男装在线零售商:Bigtex
2019/06/22 全球购物
写给女朋友的道歉信
2014/01/12 职场文书
学生会主席就职演讲稿
2014/01/14 职场文书
2014年国培研修感言
2014/03/09 职场文书
2015年监理工作总结范文
2015/04/07 职场文书
干货!开幕词的写作方法
2019/04/02 职场文书
《飘》英文读后感五篇
2019/10/11 职场文书