TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程实例简析
Sep 26 Python
Python标准库之sqlite3使用实例
Nov 25 Python
浅谈Python的垃圾回收机制
Dec 17 Python
python 编程之twisted详解及简单实例
Jan 28 Python
Python使用三种方法实现PCA算法
Dec 12 Python
使用Python机器学习降低静态日志噪声
Sep 29 Python
python dataframe向下向上填充,fillna和ffill的方法
Nov 28 Python
python 字符串常用函数详解
Sep 11 Python
Python魔法方法 容器部方法详解
Jan 02 Python
Python+Django+MySQL实现基于Web版的增删改查的示例代码
May 13 Python
Python利用Xpath选择器爬取京东网商品信息
Jun 01 Python
python调用ffmpeg命令行工具便捷操作视频示例实现过程
Nov 01 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
php 引用(&)详解
2009/11/20 PHP
php的常量和变量实例详解
2017/06/27 PHP
Yii框架组件的事件机制原理与用法分析
2020/04/07 PHP
屏蔽鼠标右键、Ctrl+n、shift+F10、F5刷新、退格键 的javascript代码
2007/04/01 Javascript
javascript实现 在光标处插入指定内容
2007/05/25 Javascript
情人节专属 纯js脚本1k大小的3D玫瑰效果
2012/02/11 Javascript
javascript匿名函数实例分析
2014/11/18 Javascript
Javascript中的call()方法介绍
2015/03/15 Javascript
js H5 canvas投篮小游戏
2016/08/18 Javascript
node.js 和HTML5开发本地桌面应用程序
2016/12/13 Javascript
react-native 完整实现登录功能的示例代码
2017/09/11 Javascript
浅谈Node.js 沙箱环境
2018/05/15 Javascript
Vue.use源码学习小结
2018/06/20 Javascript
详解Js里的for…in和for…of的用法
2019/03/28 Javascript
微信小程序获取用户绑定手机号方法示例
2019/07/21 Javascript
jqGrid表格底部汇总、合计行footerrow处理
2019/08/21 Javascript
Vue3.0中的monorepo管理模式的实现
2019/10/14 Javascript
react组件基本用法示例小结
2020/04/27 Javascript
Vue移动端项目实现使用手机预览调试操作
2020/07/18 Javascript
vue实现路由懒加载的3种方法示例
2020/09/01 Javascript
举例讲解Python中is和id的用法
2015/04/03 Python
Python编程中运用闭包时所需要注意的一些地方
2015/05/02 Python
python使用xlrd与xlwt对excel的读写和格式设定
2017/01/21 Python
python3 http提交json参数并获取返回值的方法
2018/12/19 Python
django drf框架自带的路由及最简化的视图
2019/09/10 Python
python反转列表的三种方式解析
2019/11/08 Python
英国拳击装备购物网站:RDX Sports
2018/01/23 全球购物
美国在线奢侈品寄售商店:Luxury Garage Sale
2018/08/19 全球购物
英国乐购杂货:Tesco Groceries
2018/11/29 全球购物
什么叫应用程序域?什么是受管制的代码?什么是强类型系统?什么是装箱和拆箱?
2016/08/13 面试题
韩语专业本科生求职信
2013/10/01 职场文书
英文版网络工程师求职信
2013/10/28 职场文书
2014年三八妇女节活动总结
2014/03/01 职场文书
大学生赌博检讨书
2014/09/22 职场文书
2015年高一班主任工作总结
2015/05/13 职场文书
go goroutine 怎样进行错误处理
2021/07/16 Golang