TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的__init__ 、__new__、__call__小结
Apr 25 Python
python多重继承新算法C3介绍
Sep 28 Python
让 python 命令行也可以自动补全
Nov 30 Python
Python获取任意xml节点值的方法
May 05 Python
Django Admin实现上传图片校验功能
Mar 06 Python
使用Eclipse如何开发python脚本
Apr 11 Python
python并发和异步编程实例
Nov 15 Python
详解opencv Python特征检测及K-最近邻匹配
Jan 21 Python
python ChainMap的使用和说明详解
Jun 11 Python
将pytorch转成longtensor的简单方法
Feb 18 Python
Django项目创建及管理实现流程详解
Oct 13 Python
opencv-python图像配准(匹配和叠加)的实现
Jun 23 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
Adodb的十个实例(清晰版)
2006/12/31 PHP
PHP 如何获取二维数组中某个key的集合
2014/06/03 PHP
PHP实现把文本中的URL转换为链接的auolink()函数分享
2014/07/29 PHP
PHP实现模仿socket请求返回页面的方法
2014/11/04 PHP
JS 模态对话框和非模态对话框操作技巧汇总
2013/04/15 Javascript
JS关键字球状旋转效果的实例代码
2013/11/29 Javascript
window.location.href的用法(动态输出跳转)
2014/08/09 Javascript
jQuery过滤特殊字符及JS字符串转为数字
2016/05/26 Javascript
jQuery快速实现商品数量加减的方法
2017/02/06 Javascript
angularjs实现table增加tr的方法
2018/02/27 Javascript
layui select动态添加option的实例
2018/03/07 Javascript
JS模拟实现哈希表及应用详解
2018/05/04 Javascript
vue-cli项目中使用公用的提示弹层tips或加载loading组件实例详解
2018/05/28 Javascript
vue.draggable实现表格拖拽排序效果
2018/12/01 Javascript
jQuery+css实现的点击图片放大缩小预览功能示例【图片预览 查看大图】
2020/05/29 jQuery
[02:53]DOTA2英雄基础教程 山岭巨人小小
2013/12/09 DOTA
python计算N天之后日期的方法
2015/03/31 Python
在pycharm中python切换解释器失败的解决方法
2018/10/29 Python
python集合是否可变总结
2019/06/20 Python
如何使用python实现模拟鼠标点击
2020/01/06 Python
Python3 利用face_recognition实现人脸识别的方法
2020/03/13 Python
jupyter notebook 多行输出实例
2020/04/09 Python
Pycharm常用快捷键总结及配置方法
2020/11/14 Python
基础的CSS3弹性盒Flexbox布局使用实例
2016/04/08 HTML / CSS
CSS3 实现的加载动画
2020/12/07 HTML / CSS
电信营业员自我评价分享
2014/01/17 职场文书
仓库文员岗位职责
2014/04/06 职场文书
拾金不昧锦旗标语
2014/06/27 职场文书
车辆年检委托书范本
2014/10/14 职场文书
群众路线自我剖析及整改措施
2014/11/04 职场文书
2015年“七七卢沟桥事变”纪念活动总结
2015/03/24 职场文书
小学毕业教师寄语
2019/06/21 职场文书
2019个人工作计划书的格式及范文!
2019/07/04 职场文书
详解Java分布式事务的 6 种解决方案
2021/06/26 Java/Android
仅仅使用 HTML/CSS 实现各类进度条的方式汇总
2021/11/11 HTML / CSS
vue实现移动端div拖动效果
2022/03/03 Vue.js