TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python下singleton模式的实现方法
Jul 16 Python
python smtplib模块发送SSL/TLS安全邮件实例
Apr 08 Python
在NumPy中创建空数组/矩阵的方法
Jun 15 Python
python实现监控某个服务 服务崩溃即发送邮件报告
Jun 21 Python
Python统计python文件中代码,注释及空白对应的行数示例【测试可用】
Jul 25 Python
浅谈pandas用groupby后对层级索引levels的处理方法
Nov 06 Python
简单了解python 邮件模块的使用方法
Jul 24 Python
Python实现银行账户资金交易管理系统
Jan 03 Python
tensorflow 大于某个值为1,小于为0的实例
Jun 30 Python
Python延迟绑定问题原理及解决方案
Aug 04 Python
4种非常实用的python内置数据结构
Apr 28 Python
Python Matplotlib绘制两个Y轴图像
Apr 13 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
索尼SONY ICF-7600A(W)电路分析
2021/03/01 无线电
PHP系统流量分析的程序
2006/10/09 PHP
教你如何在CI框架中使用 .htaccess 隐藏url中index.php
2014/06/09 PHP
PHP如何将XML转成数组
2016/04/04 PHP
PHP 图片处理
2020/09/16 PHP
HTML Dom与Css控制方法
2010/10/25 Javascript
jQuery的链式调用浅析
2010/12/03 Javascript
JS防止用户多次提交的简单代码
2013/08/01 Javascript
js形成页面的一种遮罩效果实例代码
2014/01/04 Javascript
原生js仿jq判断当前浏览器是否为ie,精确到ie6~8
2014/08/30 Javascript
javascript 常用验证函数总结
2016/06/28 Javascript
JavaScript事件用法浅析
2016/10/31 Javascript
js+html5生成自动排列对话框实例
2017/10/09 Javascript
vscode下vue项目中eslint的使用方法
2019/01/13 Javascript
详解Vue.js和layui日期控件冲突问题解决办法
2019/07/25 Javascript
微信小程序下拉加载和上拉刷新两种实现方法详解
2019/09/05 Javascript
JS操作json对象key、value的常用方法分析
2019/10/29 Javascript
js实现轮播图效果 z-index实现轮播图
2020/01/17 Javascript
[06:37]2014DOTA2国际邀请赛 昔日王者渴望重回巅峰
2014/07/12 DOTA
[43:24]2018DOTA2亚洲邀请赛3月29日 小组赛A组 LGD VS Liquid
2018/03/30 DOTA
python将html转成PDF的实现代码(包含中文)
2013/03/04 Python
解析Python编程中的包结构
2015/10/25 Python
Python实现的科学计算器功能示例
2017/08/04 Python
利用Python将时间或时间间隔转为ISO 8601格式方法示例
2017/09/05 Python
从头学Python之编写可执行的.py文件
2017/11/28 Python
[原创]教女朋友学Python(一)运行环境搭建
2017/11/29 Python
python多进程读图提取特征存npy
2019/05/21 Python
Python Pandas数据中对时间的操作
2019/07/30 Python
python json.dumps中文乱码问题解决
2020/04/01 Python
PyChon中关于Jekins的详细安装(推荐)
2020/12/28 Python
python中操作文件的模块的方法总结
2021/02/04 Python
捐助倡议书范文
2014/04/15 职场文书
交通事故赔偿协议书怎么写
2014/10/04 职场文书
公司开会通知
2015/04/20 职场文书
2019毕业典礼主持词!
2019/07/05 职场文书
QT连接MYSQL数据库的详细步骤
2021/07/07 MySQL