TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python监控网卡流量并使用graphite绘图的示例
Apr 27 Python
跟老齐学Python之玩转字符串(2)更新篇
Sep 28 Python
Python获取服务器信息的最简单实现方法
Mar 05 Python
Python3遍历目录树实现方法
May 22 Python
python 分离文件名和路径以及分离文件名和后缀的方法
Oct 21 Python
python实现弹窗祝福效果
Apr 07 Python
Pytorch.nn.conv2d 过程验证方式(单,多通道卷积过程)
Jan 03 Python
使用Python打造一款间谍程序的流程分析
Feb 21 Python
Python实现播放和录制声音的功能
Aug 12 Python
通过实例解析python subprocess模块原理及用法
Oct 10 Python
opencv+pyQt5实现图片阈值编辑器/寻色块阈值利器
Nov 13 Python
基于Python实现射击小游戏的制作
Apr 06 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
php解析http获取的json字符串变量总是空白null
2015/03/02 PHP
微信第三方登录(原生)demo【必看篇】
2017/05/26 PHP
PHP实现的观察者模式实例
2017/06/21 PHP
PHP高精确度运算BC函数库实例详解
2017/08/15 PHP
PDO::quote讲解
2019/01/29 PHP
PHP上传图片到数据库并显示的实例代码
2019/12/20 PHP
奇妙的js
2007/09/24 Javascript
通过jquery还原含有rowspan、colspan的table的实现方法
2012/02/10 Javascript
js 使用form表单select类实现级联菜单效果
2012/12/19 Javascript
IE的事件传递-event.cancelBubble示例介绍
2014/01/12 Javascript
jQuery使用之处理页面元素用法实例
2015/01/19 Javascript
jQuery树形控件zTree使用小结
2016/08/02 Javascript
javascript iframe跨域详解
2016/10/26 Javascript
Ajax跨域实现代码(后台jsp)
2017/01/21 Javascript
JavaScript Uploadify文件上传实例
2017/02/28 Javascript
js判断是否是手机页面
2017/03/17 Javascript
详解基于Node.js的微信JS-SDK后端接口实现代码
2017/07/15 Javascript
基于layui实现高级搜索(筛选)功能
2019/07/26 Javascript
Node.js文本文件BOM头的去除方法
2020/11/22 Javascript
[03:58]兄弟们,回来开黑了!DOTA2昔日战友招募宣传视频
2016/07/17 DOTA
python正则表达式match和search用法实例
2015/03/26 Python
python实现的希尔排序算法实例
2015/07/01 Python
python爬虫神器Pyppeteer入门及使用
2019/07/13 Python
python:按行读入,排序然后输出的方法
2019/07/20 Python
PyCharm Anaconda配置PyQt5开发环境及创建项目的教程详解
2020/03/24 Python
python中Pexpect的工作流程实例讲解
2021/03/02 Python
美国运动鞋和运动服零售商:Footaction
2017/04/07 全球购物
环保专业大学生职业规划设计
2014/01/10 职场文书
教师敬业奉献模范事迹材料
2014/05/18 职场文书
优秀家长自荐材料
2014/08/26 职场文书
判缓刑人员个人思想汇报
2014/10/10 职场文书
趣味运动会赞词
2015/07/22 职场文书
创业计划之特色精品店
2019/08/12 职场文书
Python标准库之typing的用法(类型标注)
2021/06/02 Python
浅谈sql_@SelectProvider及使用注意说明
2021/08/04 Java/Android
Windows Server 2008配置防火墙策略详解
2022/06/28 Servers