tensorflow 保存模型和取出中间权重例子


Posted in Python onJanuary 24, 2020

下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。

import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
  print(ele)
  print(reader.get_tensor(ele))


x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b


loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''


saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  if isTrain:
    for i in xrange(train_steps):
      sess.run(train, feed_dict={x: x_data})
      if (i + 1) % checkpoint_steps == 0:
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
  else:
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      pass   
    print(sess.run(w))
    print(sess.run(b))
    graph_def = tf.get_default_graph().as_graph_def()
    #通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
    with tf.gfile.FastGFile('./test.pb', 'wb') as f:
      f.write(output_graph_def.SerializeToString())


with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
  with gfile.FastGFile('./test.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
  print(sess.run(res))
  print(sess.run(graph_def))

以上这篇tensorflow 保存模型和取出中间权重例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python BeautifulSoup设置页面编码的方法
Apr 03 Python
如何用itertools解决无序排列组合的问题
May 18 Python
Python2与python3中 for 循环语句基础与实例分析
Nov 20 Python
Python第三方Window模块文件的几种安装方法
Nov 22 Python
numpy.random模块用法总结
May 27 Python
Flask框架路由和视图用法实例分析
Nov 07 Python
使用python远程操作linux过程解析
Dec 04 Python
解决Django部署设置Debug=False时xadmin后台管理系统样式丢失
Apr 07 Python
Django 解决model 反向引用中的related_name问题
May 19 Python
keras.utils.to_categorical和one hot格式解析
Jul 02 Python
python爬取代理IP并进行有效的IP测试实现
Oct 09 Python
Python django框架 web端视频加密的实例详解
Nov 20 Python
tensorflow 模型权重导出实例
Jan 24 #Python
在Tensorflow中查看权重的实现
Jan 24 #Python
tensorflow求导和梯度计算实例
Jan 23 #Python
Tensorflow的梯度异步更新示例
Jan 23 #Python
在Tensorflow中实现梯度下降法更新参数值
Jan 23 #Python
Tensorflow实现部分参数梯度更新操作
Jan 23 #Python
将tensorflow模型打包成PB文件及PB文件读取方式
Jan 23 #Python
You might like
php冒泡排序、快速排序、快速查找、二维数组去重实例分享
2014/04/24 PHP
qq登录,新浪微博登录接口申请过程中遇到的问题
2014/07/22 PHP
Zend Framework入门知识点小结
2016/03/19 PHP
js鼠标滑过弹出层的定位IE6bug解决办法
2012/12/26 Javascript
javascript中的undefined和not defined区别示例介绍
2014/02/26 Javascript
理解JavaScript表单的基础知识
2016/01/25 Javascript
javascript html实现网页版日历代码
2016/03/08 Javascript
微信小程序 使用picker封装省市区三级联动实例代码
2016/10/28 Javascript
jQuery动态添加li标签并添加属性和绑定事件方法
2018/02/24 jQuery
JS实现简单获取最近7天和最近3天日期的方法
2018/04/18 Javascript
vue系列之requireJs中引入vue-router的方法
2018/07/18 Javascript
iconfont的三种使用方式详解
2018/08/05 Javascript
浅谈vue同一页面中拥有两个表单时,的验证问题
2018/09/18 Javascript
微信小程序实现顶部下拉菜单栏
2018/11/04 Javascript
Vue侦测相关api的实现方法
2019/05/22 Javascript
jQuery实现图片随机切换、抽奖功能(实例代码)
2019/10/23 jQuery
Chrome插件开发系列一:弹窗终结者开发实战
2020/10/02 Javascript
[02:29]DOTA2英雄基础教程 陈
2013/12/17 DOTA
[01:25]2014DOTA2国际邀请赛 zhou分析LGD比赛情况
2014/07/14 DOTA
Django URL传递参数的方法总结
2016/08/28 Python
python 读写文件,按行修改文件的方法
2018/07/12 Python
numpy中loadtxt 的用法详解
2018/08/03 Python
Python字符串通过'+'和join函数拼接新字符串的性能测试比较
2019/03/05 Python
Python/Django后端使用PIL Image生成头像缩略图
2019/04/30 Python
Python 获取 datax 执行结果保存到数据库的方法
2019/07/11 Python
django rest framework 过滤时间操作
2020/07/12 Python
娇韵诗俄罗斯官方网站:Clarins俄罗斯
2020/10/03 全球购物
《胖乎乎的小手》教学反思
2014/02/26 职场文书
环保志愿者活动方案
2014/08/14 职场文书
小学生三分钟演讲稿
2014/08/18 职场文书
锦旗赠语
2015/06/23 职场文书
婚宴祝酒词大全
2015/08/10 职场文书
商业计划书格式、范文
2019/03/21 职场文书
手把手教你从零开始react+antd搭建项目
2021/06/03 Javascript
使用python求解迷宫问题的三种实现方法
2022/03/17 Python
Nginx配置之禁止指定IP访问
2022/05/02 Servers