TensorFlow saver指定变量的存取


Posted in Python onMarch 10, 2018

今天和大家分享一下用TensorFlow的saver存取训练好的模型那点事。

1. 用saver存取变量;
2. 用saver存取指定变量。

用saver存取变量。

话不多说,先上代码

# coding=utf-8
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
saver =tf.train.Saver()
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)

这里我随便定义了几个变量然后进行存操作,运行后,变量w,b,s会被保存下来。保存会生成如下几个文件:

  • cheakpoint
  • save_net.ckpt.data-*
  • save_net.ckpt.index
  • save_net.ckpt.meta

接下来是读取的代码

import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
saver =tf.train.Saver()
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 print ("s",sess.run(a))

在写读取代码时要注意变量定义的类型、大小和变量的数量以及顺序等要与存的时候一致,不然会报错。你存的时候顺序是w,b,s,取的时候同样这个顺序。存的时候w定义了dtype没有 定义name,取的时候同样要这样,因为TensorFlow存取是按照键值对来存取的,所以必须一致。这里变量名,也就是w,s之类可以不同。

如下是我成功读取的效果

TensorFlow saver指定变量的存取

用saver存取指定变量。

在我们做训练时候,有些变量是没有必要保存的,但是如果直接用tf.train.Saver()。程序会将所有的变量保存下来,这时候我们可以指定保存,只保存我们需要的变量,其他的统统丢掉。
其实很简单,只需要在上面代码基础上稍加修改,只需把tf.train.Saver()替换成如下代码

program = []
program += [w,b]
tf.train.Saver(program)

这样,程序就只会存w和b了。同样,读取程序里面的tf.train.Saver()也要做如上修改。dtype,name之类依旧必须一致。

最后附上最终代码:

# coding=utf-8
# saver保存变量测试
import os        
import tensorflow as tf
import numpy
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #有些指令集没有装,加这个不显示那些警告
w = tf.Variable([[1,2,3],[2,3,4],[6,7,8]],dtype=tf.float32)
b = tf.Variable([[4,5,6]],dtype=tf.float32,)
s = tf.Variable([[2, 5],[5, 6]], dtype=tf.float32)
init = tf.global_variables_initializer()
program = []
program += [w, b]
saver =tf.train.Saver(program)
with tf.Session() as sess:
 sess.run(init)
 save_path = saver.save(sess, "save_net.ckpt")#路径可以自己定
 print("save to path:",save_path)
#saver提取变量测试
import tensorflow as tf
import os
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

w = tf.Variable(np.arange(9).reshape((3,3)),dtype=tf.float32)
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32)
a = tf.Variable(np.arange(4).reshape((2,2)),dtype=tf.float32)
program = []
program +=[w,b]
saver =tf.train.Saver(program)
with tf.Session() as sess:

 saver.restore(sess,'save_net.ckpt')
 print ("weights",sess.run(w))
 print ("b",sess.run(b))
 #print ("s",sess.run(a))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python不规范的日期字符串处理类
Jun 10 Python
Python 获取新浪微博的最新公共微博实例分享
Jul 03 Python
一个计算身份证号码校验位的Python小程序
Aug 15 Python
简单分析Python中用fork()函数生成的子进程
May 04 Python
Python中顺序表的实现简单代码分享
Jan 09 Python
Python3 利用requests 库进行post携带账号密码请求数据的方法
Oct 26 Python
Python中py文件转换成exe可执行文件的方法
Jun 14 Python
使用python实现ftp的文件读写方法
Jul 02 Python
Python代码太长换行的实现
Jul 05 Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 Python
python -v 报错问题的解决方法
Sep 15 Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
Python使用pyh生成HTML文档的方法示例
Mar 10 #Python
tensorflow获取变量维度信息
Mar 10 #Python
TensorFlow变量管理详解
Mar 10 #Python
You might like
微信支付的开发流程详解
2016/09/13 PHP
thinkphp5.0整合phpsocketio完整攻略(绕坑)
2018/10/12 PHP
laravel5实现微信第三方登录功能
2018/12/06 PHP
PHP匿名函数(闭包函数)详解
2019/03/22 PHP
PHP设计模式(五)适配器模式Adapter实例详解【结构型】
2020/05/02 PHP
获取URL地址中的文件名和参数的javascript代码
2009/09/02 Javascript
javascript 添加和移除函数的通用方法
2009/10/20 Javascript
jQuery chili图片远处放大插件
2009/11/30 Javascript
jQuery操作JSON的CRUD用法实例
2015/02/25 Javascript
jQuery使用post方法提交数据实例
2015/03/25 Javascript
JavaScript如何获取数组最大值和最小值
2015/11/18 Javascript
Javascript原型链的原理详解
2016/01/05 Javascript
AngularJS整合Springmvc、Spring、Mybatis搭建开发环境
2016/02/25 Javascript
利用JQuery写一个简单的异步分页插件
2016/03/07 Javascript
Ajax分页插件Pagination从前台jQuery到后端java总结
2016/07/22 Javascript
js替换字符串中所有指定的字符(实现代码)
2016/08/17 Javascript
BootStrap实现手机端轮播图左右滑动事件
2016/10/13 Javascript
第一次接触神奇的前端框架vue.js
2016/12/01 Javascript
最常见和最有用的字符串相关的方法详解
2017/02/06 Javascript
Python实现Pig Latin小游戏实例代码
2018/02/02 Python
python画一个玫瑰和一个爱心
2020/08/18 Python
Python设计模式之代理模式实例详解
2019/01/19 Python
计算机二级python学习教程(1) 教大家如何学习python
2019/05/16 Python
解决pyecharts在jupyter notebook中使用报错问题
2020/04/23 Python
python利用datetime模块计算程序运行时间问题
2020/02/20 Python
浅谈在django中使用redirect重定向数据传输的问题
2020/03/13 Python
Python Pandas 对列/行进行选择,增加,删除操作
2020/05/17 Python
keras K.function获取某层的输出操作
2020/06/29 Python
python实现测试工具(二)——简单的ui测试工具
2020/10/19 Python
上课迟到检讨书
2014/01/19 职场文书
幼儿园元旦亲子活动方案
2014/02/17 职场文书
新年爱情寄语
2014/04/08 职场文书
2015年医院工作总结范文
2015/04/09 职场文书
员工手册董事长致辞
2015/07/29 职场文书
教师学习十八届五中全会精神心得体会
2016/01/05 职场文书
python 爬取天气网卫星图片
2021/06/07 Python