TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python任务调度实例分析
May 19 Python
Python的条件语句与运算符优先级详解
Oct 13 Python
python对离散变量的one-hot编码方法
Jul 11 Python
python pandas消除空值和空格以及 Nan数据替换方法
Oct 30 Python
django小技巧之html模板中调用对象属性或对象的方法
Nov 30 Python
想学python 这5本书籍你必看!
Dec 11 Python
python使用suds调用webservice接口的方法
Jan 03 Python
python学习--使用QQ邮箱发送邮件代码实例
Apr 16 Python
python将excel转换为csv的代码方法总结
Jul 03 Python
python3 sorted 如何实现自定义排序标准
Mar 12 Python
学习Python爬虫的几点建议
Aug 05 Python
python数据分析之用sklearn预测糖尿病
Apr 22 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
PHP数字字符串左侧补0、字符串填充和自动补齐的几种方法
2014/05/10 PHP
大家在抢红包,程序员在研究红包算法
2015/08/31 PHP
JavaScript与HTML结合的基本使用方法整理
2015/10/12 PHP
php 开发中加密的几种方法总结
2017/03/22 PHP
TP5(thinkPHP5)框架使用ajax实现与后台数据交互的方法小结
2020/02/10 PHP
jquery实现固定顶部导航效果(仿蘑菇街)
2013/03/21 Javascript
在js文件中写el表达式取不到值的原因及解决方法
2013/12/23 Javascript
jQuery实现的自定义滚动条实例详解
2016/09/20 Javascript
浅谈React 属性和状态的一些总结
2016/11/21 Javascript
浅析JavaScript中break、continue和return的区别
2016/11/30 Javascript
BootStrap Validator对于隐藏域验证和程序赋值即时验证的问题浅析
2016/12/01 Javascript
js中this对象用法分析
2018/01/05 Javascript
Js通过AES加密后PHP用Openssl解密的方法
2019/07/12 Javascript
vue.js购物车添加商品组件的方法
2019/09/17 Javascript
js页面加载后执行的几种方式小结
2020/01/30 Javascript
js、jquery实现列表模糊搜索功能过程解析
2020/03/27 jQuery
[54:43]DOTA2-DPC中国联赛 正赛 CDEC vs Dynasty BO3 第一场 2月22日
2021/03/11 DOTA
Python threading多线程编程实例
2014/09/18 Python
Python列表(list)常用操作方法小结
2015/02/02 Python
Python3写入文件常用方法实例分析
2015/05/22 Python
Python实现八大排序算法
2016/08/13 Python
Python 3中的yield from语法详解
2017/01/18 Python
python3+PyQt5实现使用剪贴板做复制与粘帖示例
2017/01/24 Python
python实现12306抢票及自动邮件发送提醒付款功能
2018/03/08 Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
2019/08/17 Python
python数值基础知识浅析
2019/11/19 Python
Selenium启动Chrome时配置选项详解
2020/03/18 Python
使用python将微信image下.dat文件解密为.png的方法
2020/11/30 Python
python如何构建mock接口服务
2021/01/28 Python
阿拉伯世界最大的电子卖场:Souq埃及
2016/08/01 全球购物
竞选文艺委员演讲稿
2014/04/28 职场文书
中药学专业毕业生推荐信
2014/07/10 职场文书
党的群众路线教育实践活动个人对照检查材料(校长)
2014/11/05 职场文书
计算机专业自荐信范文
2015/03/26 职场文书
2015年校本培训工作总结
2015/07/24 职场文书
mybatis-plus模糊查询指定字段
2022/04/28 Java/Android