tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现dict版图遍历示例
Feb 19 Python
详解Python的Django框架中的中间件
Jul 24 Python
简单谈谈python基本数据类型
Sep 26 Python
Python supervisor强大的进程管理工具的使用
Apr 24 Python
pandas dataframe的合并实现(append, merge, concat)
Jun 24 Python
Python求两点之间的直线距离(2种实现方法)
Jul 07 Python
在django admin详情表单显示中添加自定义控件的实现
Mar 11 Python
解决启动django,浏览器显示“服务器拒绝访问”的问题
May 13 Python
Python引入多个模块及包的概念过程解析
Sep 21 Python
python 读取yaml文件的两种方法(在unittest中使用)
Dec 01 Python
Python实现区域填充的示例代码
Feb 03 Python
Python爬虫之爬取二手房信息
Apr 27 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
Look And Say 序列php实现代码
2011/05/22 PHP
使用php+apc实现上传进度条且在IE7下不显示的问题解决方法
2013/04/25 PHP
thinkphp视图模型查询提示ERR: 1146:Table 'db.pr_order_view' doesn't exist的解决方法
2014/10/30 PHP
jQuery学习之prop和attr的区别示例介绍
2013/11/15 Javascript
jquery中each遍历对象和数组示例
2014/08/05 Javascript
基于nodejs+express(4.x+)实现文件上传功能
2015/11/23 NodeJs
深入理解Java线程编程中的阻塞队列容器
2015/12/07 Javascript
three.js实现3D影院的原理的代码分析
2017/12/18 Javascript
jQuery+SpringMVC中的复选框选择与传值实例
2018/01/08 jQuery
详解Vue结合后台的列表增删改案例
2018/08/21 Javascript
vue中的ref和$refs的使用
2018/11/22 Javascript
node.js express框架实现文件上传与下载功能实例详解
2019/10/15 Javascript
webpack 如何同时输出压缩和未压缩的文件的实现步骤
2020/06/05 Javascript
[01:25:09]2014 DOTA2国际邀请赛中国区预选赛 5 23 CIS VS DT第二场
2014/05/24 DOTA
python实现根据图标提取分类应用程序实例
2014/09/28 Python
自动化Nginx服务器的反向代理的配置方法
2015/06/28 Python
Python代码缩进和测试模块示例详解
2018/05/07 Python
基于pycharm导入模块显示不存在的解决方法
2018/10/13 Python
DataFrame:通过SparkSql将scala类转为DataFrame的方法
2019/01/29 Python
Python序列对象与String类型内置方法详解
2019/10/22 Python
使用python去除图片白色像素的实例
2019/12/12 Python
Python selenium使用autoIT上传附件过程详解
2020/05/26 Python
使用CSS3 制作一个material-design 风格登录界面实例
2016/12/12 HTML / CSS
教你使用Canvas处理图片的方法
2017/11/28 HTML / CSS
美国照明、家居装饰和家具购物网站:Bellacor
2017/09/20 全球购物
制药工程专业个人求职自荐信
2014/01/25 职场文书
元旦促销方案
2014/03/15 职场文书
工作会议方案
2014/05/21 职场文书
意外死亡赔偿协议书
2014/10/14 职场文书
开幕式邀请函
2015/01/31 职场文书
2015年度党员自我评价范文
2015/03/03 职场文书
2015年社区矫正工作总结
2015/04/21 职场文书
2015年司法所工作总结
2015/04/27 职场文书
企业工会工作总结2015
2015/05/13 职场文书
MySQL高速缓存启动方法及参数详解(query_cache_size)
2021/07/01 MySQL
Mysql如何查看是否使用到索引
2022/12/24 MySQL