tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常见数据结构详解
Jul 24 Python
Python使用poplib模块和smtplib模块收发电子邮件的教程
Jul 02 Python
python字符串str和字节数组相互转化方法
Mar 18 Python
python利用MethodType绑定方法到类示例代码
Aug 27 Python
python计算日期之间的放假日期
Jun 05 Python
python 实时得到cpu和内存的使用情况方法
Jun 11 Python
Python错误处理操作示例
Jul 18 Python
dpn网络的pytorch实现方式
Jan 14 Python
Pycharm中Python环境配置常见问题解析
Jan 16 Python
Python Pillow(PIL)库的用法详解
Sep 19 Python
Opencv python 图片生成视频的方法示例
Nov 18 Python
Python可视化学习之seaborn绘制矩阵图详解
Feb 24 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
ASP和PHP都是可以删除自身的
2007/04/09 PHP
PHP与SQL注入攻击[一]
2007/04/17 PHP
PHP开发的一些注意点总结
2010/10/12 PHP
PHP实现文件上传与下载实例与总结
2016/03/13 PHP
PHP中关键字interface和implements详解
2017/06/14 PHP
如何在标题栏显示框架内页面的标题
2007/02/03 Javascript
JavaScript 异步调用框架 (Part 3 - 代码实现)
2009/08/04 Javascript
在javascript中关于节点内容加强
2013/04/11 Javascript
JQuery结合CSS操作打印样式的方法
2013/12/24 Javascript
javascript+HTML5 Canvas绘制转盘抽奖
2020/05/16 Javascript
使用jquery.qrcode.js生成二维码插件
2016/10/17 Javascript
Nodejs 搭建简单的Web服务器详解及实例
2016/11/30 NodeJs
详解Angular 4.x 动态创建组件
2017/04/25 Javascript
Angular2.js实现表单验证详解
2017/06/23 Javascript
浅谈基于Vue.js的移动组件库cube-ui
2017/12/20 Javascript
vue+elementUI实现图片上传功能
2019/08/20 Javascript
vue获取data数据改变前后的值方法
2019/11/07 Javascript
python中文乱码的解决方法
2013/11/04 Python
用Python写王者荣耀刷金币脚本
2017/12/21 Python
人生苦短我用python python如何快速入门?
2018/03/12 Python
Python实现字符串的逆序 C++字符串逆序算法
2020/05/28 Python
python日志模块logbook使用方法
2019/09/19 Python
Python 没有main函数的原因
2020/07/10 Python
python3实现飞机大战
2020/11/29 Python
python3实现名片管理系统(控制台版)
2020/11/29 Python
使用CSS3和Checkbox实现JQuery的一些效果
2015/08/03 HTML / CSS
html5页面结构_动力节点Java学院整理
2017/07/10 HTML / CSS
香港零食网购:上仓胃子
2020/06/08 全球购物
自我鉴定模板
2013/10/29 职场文书
基层党支部整改方案
2014/10/25 职场文书
学生个人总结范文
2015/02/15 职场文书
小学生六年级作文之关于感恩
2019/08/16 职场文书
vue使用wavesurfer.js解决音频可视化播放问题
2022/04/04 Vue.js
Python写情书? 10行代码展示如何把情书写在她的照片里
2022/04/21 Python
Java异常体系非正常停止和分类
2022/06/14 Java/Android
MySQL远程无法连接的一些常见原因总结
2022/09/23 MySQL