TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中threading超线程用法实例分析
May 16 Python
python简单获取数组元素个数的方法
Jul 13 Python
解决pyqt中ui编译成窗体.py中文乱码的问题
Dec 23 Python
Python正则替换字符串函数re.sub用法示例
Jan 19 Python
Python实现霍夫圆和椭圆变换代码详解
Jan 12 Python
对Python Pexpect 模块的使用说明详解
Feb 14 Python
使用 Python 玩转 GitHub 的贡献板(推荐)
Apr 04 Python
python生成随机红包的实例写法
Sep 02 Python
关于numpy中eye和identity的区别详解
Nov 29 Python
在Keras中实现保存和加载权重及模型结构
Jun 15 Python
基于Django快速集成Echarts代码示例
Dec 01 Python
django中ImageField的使用详解
Dec 21 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
ubuntu 编译安装php 5.3.3+memcache的方法
2010/08/05 PHP
php获取服务器信息的实现代码
2013/02/04 PHP
详解PHP的Yii框架中日志的相关配置及使用
2015/12/08 PHP
PHP判断是否是微信打开还是浏览器打开的方法
2019/02/27 PHP
从sohu弄下来的flash中展示图片的代码
2007/04/27 Javascript
javascript getElementsByName()的用法说明
2009/07/31 Javascript
javascript学习笔记(十二) RegExp类型介绍
2012/06/20 Javascript
JavaScript高级程序设计(第3版)学习笔记5 js语句
2012/10/11 Javascript
css样式标签和js语法属性区别
2013/11/06 Javascript
jquery validate添加自定义验证规则(验证邮箱 邮政编码)
2013/12/04 Javascript
js中同步与异步处理的方法和区别总结
2013/12/25 Javascript
jQuery设置与获取HTML,文本和值的简单实例
2014/02/26 Javascript
js中数组排序sort方法的原理分析
2014/11/20 Javascript
javascript中基本类型和引用类型的区别分析
2015/05/12 Javascript
AngularJS中$interval的用法详解
2016/02/02 Javascript
js 输入框 正则表达式(菜鸟必看教程)
2017/02/19 Javascript
JavaScript轮播图简单制作方法
2017/02/20 Javascript
JavaScript根据json生成html表格的示例代码
2018/10/24 Javascript
微信小程序--获取用户地理位置名称(无须用户授权)的方法
2019/04/29 Javascript
在微信小程序中使用vant的方法
2019/06/07 Javascript
JavaScript实现省市联动效果
2019/11/22 Javascript
python模拟登录百度贴吧(百度贴吧登录)实例
2013/12/18 Python
python中的yield使用方法
2014/02/11 Python
Python cookbook(数据结构与算法)将序列分解为单独变量的方法
2018/02/13 Python
python实现顺时针打印矩阵
2019/03/02 Python
Python socket连接中的粘包、精确传输问题实例分析
2020/03/24 Python
python 在threading中如何处理主进程和子线程的关系
2020/04/25 Python
Converse匡威法国官网:美国著名帆布鞋品牌
2018/12/05 全球购物
物流管理专业大学生自荐信
2013/10/04 职场文书
关于幼儿的自我评价
2013/12/18 职场文书
环境工程专业自荐信范文
2014/03/18 职场文书
师德师风自我评价范文
2014/09/11 职场文书
篮球友谊赛通讯稿
2014/10/10 职场文书
暑假开始了,你的暑假学习计划写好了吗?
2019/07/04 职场文书
MySQL 条件查询的常用操作
2022/04/28 MySQL
Python 一键获取电脑浏览器的账号密码
2022/05/11 Python