TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python编写微信远程控制电脑的程序
Jan 05 Python
python读取文本中数据并转化为DataFrame的实例
Apr 10 Python
Python基于递归算法实现的汉诺塔与Fibonacci数列示例
Apr 18 Python
PyQt5实现无边框窗口的标题拖动和窗口缩放
Apr 19 Python
django静态文件加载的方法
May 20 Python
python2 与 pyhton3的输入语句写法小结
Sep 10 Python
如何基于python生成list的所有的子集
Nov 11 Python
Django models filter筛选条件详解
Mar 16 Python
Python count函数使用方法实例解析
Mar 23 Python
Keras loss函数剖析
Jul 06 Python
Django 实现图片上传和下载功能
Dec 31 Python
python 算法题——快乐数的多种解法
May 27 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
php 信息采集程序代码
2009/03/17 PHP
PHP性能优化准备篇图解PEAR安装
2011/12/05 PHP
Yii框架参数化查询中IN查询只能查询一个的解决方法
2017/05/20 PHP
Laravel框架实现的记录SQL日志功能示例
2018/06/19 PHP
jQuery与ExtJS之选择实例分析
2010/08/19 Javascript
JavaScript高级程序设计 扩展--关于动态原型
2010/11/09 Javascript
ajax请求get与post的区别总结
2013/11/04 Javascript
javascript回车完美实现tab切换功能
2014/03/13 Javascript
jquery.validate.js插件使用经验记录
2014/07/02 Javascript
使用Ajax生成的Excel文件并下载的实例
2016/11/21 Javascript
微信小程序 数据交互与渲染实例详解
2017/01/21 Javascript
Node.js实现连接mysql数据库功能示例
2017/09/15 Javascript
微信小程序自定义prompt组件步骤详解
2018/06/12 Javascript
在Vue中使用axios请求拦截的实现方法
2018/10/25 Javascript
详解js实时获取并显示当前时间的方法
2019/05/10 Javascript
js 判断当前时间是否处于某个一个时间段内
2019/09/19 Javascript
vue 路由子组件created和mounted不起作用的解决方法
2019/11/05 Javascript
解决VUE mounted 钩子函数执行时 img 未加载导致页面布局的问题
2020/07/27 Javascript
js实现3D旋转相册
2020/08/02 Javascript
[01:30:54]《加油DOTA》 第三期
2014/08/18 DOTA
使用Protocol Buffers的C语言拓展提速Python程序的示例
2015/04/16 Python
微信跳一跳python辅助软件思路及图像识别源码解析
2018/01/04 Python
python requests post多层字典的方法
2018/12/27 Python
python3.6数独问题的解决
2019/01/21 Python
Python获取一个用户名的组ID过程解析
2019/09/03 Python
春节到了 教你使用python来抢票回家
2020/01/06 Python
Python 字符串处理特殊空格\xc2\xa0\t\n Non-breaking space
2020/02/23 Python
Python常驻任务实现接收外界参数代码解析
2020/07/21 Python
Python 实现PS滤镜的旋涡特效
2020/12/03 Python
PyCharm Ctrl+Shift+F 失灵的简单有效解决操作
2021/01/15 Python
CSS3,线性渐变(linear-gradient)的使用总结
2017/01/09 HTML / CSS
2015年感恩节活动总结
2015/03/24 职场文书
离婚代理词范文
2015/05/23 职场文书
行政处罚告知书
2015/07/01 职场文书
运动会班级前导词
2015/07/20 职场文书
利用html+css实现菜单栏缓慢下拉效果的示例代码
2021/03/30 HTML / CSS