TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python笔记(1) 关于我们应不应该继续学习python
Oct 24 Python
python获取当前时间对应unix时间戳的方法
May 15 Python
利用pyinstaller或virtualenv将python程序打包详解
Mar 22 Python
python实现按长宽比缩放图片
Jun 07 Python
对Python 检查文件名是否规范的实例详解
Jun 10 Python
python使用writerows写csv文件产生多余空行的处理方法
Aug 01 Python
Python3操作Excel文件(读写)的简单实例
Sep 02 Python
Python内置方法实现字符串的秘钥加解密(推荐)
Dec 09 Python
Python进阶之迭代器与迭代器切片教程
Jan 29 Python
tensorflow实现训练变量checkpoint的保存与读取
Feb 10 Python
Python数据结构dict常用操作代码实例
Mar 12 Python
Django 解决阿里云部署同步数据库报错的问题
May 14 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
关于查看MSSQL 数据库 用户每个表 占用的空间大小
2013/06/21 PHP
PHP仿博客园 个人博客(1) 数据库与界面设计
2013/07/05 PHP
php实现两个数组相加的方法
2015/02/17 PHP
php简单备份与还原MySql的方法
2016/05/09 PHP
基于JQuery 的消息提示框效果代码
2011/07/31 Javascript
基于jquery封装的一个js分页
2011/11/15 Javascript
Jquery提交表单 Form.js官方插件介绍
2012/03/01 Javascript
基于jquery的跟随屏幕滚动代码
2012/07/24 Javascript
兼容IE、firefox以及chrome的js获取时间(getFullYear)
2014/07/04 Javascript
node.js中的fs.openSync方法使用说明
2014/12/17 Javascript
jQuery判断一个元素是否可见的方法
2015/06/05 Javascript
JavaScript实现上下浮动的窗口效果代码
2015/10/12 Javascript
jQuery.Callbacks()回调函数队列用法详解
2016/06/14 Javascript
使用angular帮你实现拖拽的示例
2017/07/05 Javascript
Vue项目引进ElementUI组件的方法
2018/11/11 Javascript
基于JS正则表达式实现模板数据动态渲染(实现思路详解)
2020/03/07 Javascript
JS如何在数组指定位置插入元素
2020/03/10 Javascript
Node.js path模块,获取文件后缀名操作
2020/11/07 Javascript
[46:59]完美世界DOTA2联赛PWL S2 GXR vs Ink 第二场 11.19
2020/11/20 DOTA
python web基础之加载静态文件实例
2018/03/20 Python
Python发送http请求解析返回json的实例
2018/03/26 Python
对Python实现简单的API接口实例讲解
2018/12/10 Python
python学生管理系统开发
2019/01/30 Python
关于tf.nn.dynamic_rnn返回值详解
2020/01/20 Python
python实现密度聚类(模板代码+sklearn代码)
2020/04/27 Python
Python 代码调试技巧示例代码
2020/08/11 Python
材料物理专业大学毕业生求职信
2013/10/15 职场文书
军训拉歌口号
2014/06/13 职场文书
求职教师自荐书
2014/06/19 职场文书
不服从上级领导安排的检讨书
2014/09/14 职场文书
党员剖析材料范文
2014/12/18 职场文书
酒店保洁员岗位职责
2015/02/26 职场文书
法定授权委托证明书
2015/06/18 职场文书
2019入党申请书范文3篇
2019/08/21 职场文书
iPhone13再次曝光
2021/04/15 数码科技
Python趣味挑战之教你用pygame画进度条
2021/05/31 Python