TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python实现正则匹配检索远端FTP目录下的文件
Mar 25 Python
Python如何为图片添加水印
Nov 25 Python
python 实现将字典dict、列表list中的中文正常显示方法
Jul 06 Python
Python实现的多进程拷贝文件并显示百分比功能示例
Apr 09 Python
使用Python实现跳帧截取视频帧
May 31 Python
Python完成毫秒级抢淘宝大单功能
Jun 06 Python
Python中拆分字符串的操作方法
Jul 23 Python
django基于restframework的CBV封装详解
Aug 08 Python
命令行运行Python脚本时传入参数的三种方式详解
Oct 11 Python
Python调用OpenCV实现图像平滑代码实例
Jun 19 Python
Python 程序报错崩溃后如何倒回到崩溃的位置(推荐)
Jun 23 Python
Python通过yagmail实现发送邮件代码解析
Oct 27 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
php array_map array_multisort 高效处理多维数组排序
2009/06/11 PHP
基于php实现的php代码加密解密类完整实例
2016/10/12 PHP
Ajax中的JSON格式与php传输过程全面解析
2017/11/14 PHP
PHP unlink与rmdir删除目录及目录下所有文件实例代码
2018/02/07 PHP
PHP删除数组中特定元素的两种方法
2019/02/28 PHP
Javascript Function对象扩展之延时执行函数
2010/07/06 Javascript
JQuery判断HTML元素是否存在的两种解决方法
2013/12/26 Javascript
jQuery 浮动导航菜单适合购物商品类型的网站
2014/09/09 Javascript
吐槽一下我所了解的Node.js
2014/10/08 Javascript
jQuery中element选择器用法实例
2014/12/29 Javascript
JS+CSS实现带小三角指引的滑动门效果
2015/09/22 Javascript
jQuery自定义组件(导入组件)
2016/11/08 Javascript
Easyui笔记2:实现datagrid多行删除的示例代码
2017/01/14 Javascript
angular实现商品筛选功能
2017/02/01 Javascript
Zepto实现密码的隐藏/显示
2017/04/07 Javascript
jQuery实现动态生成表格并为行绑定单击变色动作的方法
2017/04/17 jQuery
vue实现页面滚动到底部刷新
2019/08/16 Javascript
原生JS实现留言板
2020/03/26 Javascript
vue的$http的get请求要加上params操作
2020/11/12 Javascript
python安装mysql-python简明笔记(ubuntu环境)
2016/06/25 Python
python使用json序列化datetime类型实例解析
2018/02/11 Python
浅谈python之新式类
2018/08/12 Python
Django配置MySQL数据库的完整步骤
2019/09/07 Python
Python的Tqdm模块实现进度条配置
2021/02/24 Python
python连接手机自动搜集蚂蚁森林能量的实现代码
2021/02/24 Python
GetYourGuide台湾:预订旅游活动、景点和旅游项目
2019/06/10 全球购物
英语自我评价范文
2014/01/24 职场文书
小区消防演习方案
2014/02/21 职场文书
小学生运动会报道稿
2014/09/12 职场文书
农村环境卫生倡议书
2015/04/29 职场文书
运动会班级口号霸气押韵
2015/12/24 职场文书
python函数指定默认值的实例讲解
2021/03/29 Python
pytorch 一行代码查看网络参数总量的实现
2021/05/12 Python
用Python进行栅格数据的分区统计和批量提取
2021/05/27 Python
《艾尔登法环》1.03.3补丁上线 碎星伤害调整
2022/04/06 其他游戏
Android Flutter实现图片滑动切换效果
2022/04/07 Java/Android