tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python分割TXT文件成4K的TXT文件
May 23 Python
python实现文件名批量替换和内容替换
Mar 20 Python
python机器学习之神经网络(三)
Dec 20 Python
python调用支付宝支付接口流程
Aug 15 Python
python3中的eval和exec的区别与联系
Oct 10 Python
解决windows下python3使用multiprocessing.Pool出现的问题
Apr 08 Python
python如何调用java类
Jul 05 Python
Python定义一个Actor任务
Jul 29 Python
Python实现Canny及Hough算法代码实例解析
Aug 06 Python
Django-Scrapy生成后端json接口的方法示例
Oct 06 Python
详解selenium + chromedriver 被反爬的解决方法
Oct 28 Python
解决python3.6用cx_Oracle库连接Oracle的问题
Dec 07 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
一个ORACLE分页程序,挺实用的.
2006/10/09 PHP
codeigniter框架The URI you submitted has disallowed characters错误解决方法
2014/05/06 PHP
PHP+Ajax检测用户名或邮件注册时是否已经存在实例教程
2014/08/23 PHP
浅谈PHP拦截器之__set()与__get()的理解与使用方法
2016/10/18 PHP
php使用自定义函数实现汉字分割替换功能示例
2017/01/30 PHP
php-fpm重启导致的程序执行中断问题详解
2019/04/29 PHP
关于Laravel参数验证的一些疑与惑
2019/11/19 PHP
js常用排序实现代码
2010/12/28 Javascript
jquery 删除字符串最后一个字符的方法解析
2014/02/11 Javascript
JavaScript中匿名、命名函数的性能测试
2014/09/04 Javascript
javascript操作Cookie(设置、读取、删除)方法详解
2015/03/18 Javascript
JavaScript获取当前日期是星期几的方法
2015/04/06 Javascript
Javascript 是你的高阶函数(高级应用)
2015/06/15 Javascript
详解使用Vue.Js结合Jquery Ajax加载数据的两种方式
2017/01/10 Javascript
Angularjs+bootstrap+table多选(全选)支持单击行选中实现编辑、删除功能
2017/03/27 Javascript
浅谈struts1 & jquery form 文件异步上传
2017/05/25 jQuery
jQuery实现的页面详情展开收起功能示例
2018/06/11 jQuery
JQueryDOM之样式操作
2019/03/27 jQuery
JS实现动态倒计时功能(天数、时、分、秒)
2019/12/12 Javascript
[01:22:10]Ti4 循环赛第二日 DK vs Empire
2014/07/11 DOTA
python和ruby,我选谁?
2017/09/13 Python
将python文件打包成EXE应用程序的方法
2019/05/22 Python
python打包exe开机自动启动的实例(windows)
2019/06/28 Python
python shutil文件操作工具使用实例分析
2019/12/25 Python
用python拟合等角螺线的实现示例
2019/12/27 Python
Python中使用filter过滤列表的一个小技巧分享
2020/05/02 Python
移动端html5 meta标签的神奇功效
2016/01/06 HTML / CSS
基于HTML5 WebGL的3D机房的示例
2018/03/16 HTML / CSS
美国Lolё官网:购买大胆而美丽的女性运动服装
2017/05/22 全球购物
瑞士网球商店:Tennis-Point
2020/03/12 全球购物
给导游的表扬信
2014/01/10 职场文书
最新茶叶店创业计划书
2014/01/14 职场文书
产品促销活动策划书
2014/01/15 职场文书
群众对十八届四中全会的期盼
2014/10/17 职场文书
Ajax请求超时与网络异常处理图文详解
2021/05/23 Javascript
mysql 索引合并的使用
2021/08/30 MySQL