tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中给List添加元素的4种方法分享
Nov 28 Python
python实现马耳可夫链算法实例分析
May 20 Python
Python实现建立SSH连接的方法
Jun 03 Python
tensorflow实现对图片的读取的示例代码
Feb 12 Python
Python 实现选择排序的算法步骤
Apr 22 Python
Python使用matplotlib绘制随机漫步图
Aug 27 Python
[原创]Python入门教程5. 字典基本操作【定义、运算、常用函数】
Nov 01 Python
django rest framework vue 实现用户登录详解
Jul 29 Python
Flask框架模板继承实现方法分析
Jul 31 Python
Django关于admin的使用技巧和知识点
Feb 10 Python
python 中的paramiko模块简介及安装过程
Feb 29 Python
Python基础之元类详解
Apr 29 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
php-accelerator网站加速PHP缓冲的方法
2008/07/30 PHP
php adodb连接mssql解决乱码问题
2009/06/12 PHP
php Smarty date_format [格式化时间日期]
2010/03/15 PHP
PHP中去掉字符串首尾空格的方法
2012/05/19 PHP
php中防止SQL注入的最佳解决方法
2013/04/25 PHP
PHP利用header跳转失效的解决方法
2014/10/24 PHP
php+ajax制作无刷新留言板
2015/10/27 PHP
PHP面向对象五大原则之开放-封闭原则(OCP)详解
2018/04/04 PHP
Yii Framework框架使用PHPExcel组件的方法示例
2019/07/24 PHP
(推荐一个超好的JS函数库)S.Sams Lifexperience ScriptClassLib
2007/04/29 Javascript
jQuery 插件开发指南
2014/11/14 Javascript
jquery简单实现图片切换效果的方法
2015/05/12 Javascript
AngularJS表单详解及示例代码
2016/08/17 Javascript
Vue.js之slot深度复制详解
2017/03/10 Javascript
bootstrap轮播模板使用方法详解
2017/11/17 Javascript
用Node编写RESTful API接口的示例代码
2018/07/04 Javascript
layui前端框架之table表数据的刷新方法
2018/08/17 Javascript
JS实现倒序输出的几种常用方法示例
2019/04/13 Javascript
vue跳转同一个组件,参数不同,页面接收值只接收一次的解决方法
2019/11/05 Javascript
js中火星坐标、百度坐标、WGS84坐标转换实现方法示例
2020/03/02 Javascript
原生JavaScript实现五子棋游戏
2020/11/09 Javascript
[46:25]DOTA2上海特级锦标赛主赛事日 - 4 败者组第五轮 MVP.Phx VS EG第二局
2016/03/05 DOTA
Python设计模式之观察者模式简单示例
2018/01/10 Python
python实现可视化动态CPU性能监控
2018/06/21 Python
Python 给某个文件名添加时间戳的方法
2018/10/16 Python
python3 打开外部程序及关闭的示例
2018/11/06 Python
python自动生成证件号的方法示例
2021/01/14 Python
使用python tkinter开发一个爬取B站直播弹幕工具的实现代码
2021/02/07 Python
一款纯css3实现的鼠标悬停动画按钮
2014/12/29 HTML / CSS
法人任命书范本
2014/06/04 职场文书
大学专科自荐信
2014/06/17 职场文书
关于读书的演讲稿800字
2014/08/27 职场文书
物业工程部主管岗位职责
2015/04/16 职场文书
Python进度条的使用
2021/05/17 Python
Mysql Online DDL的使用详解
2021/05/20 MySQL
如何理解python接口自动化之logging日志模块
2021/06/15 Python