解决tensorflow模型参数保存和加载的问题


Posted in Python onJuly 26, 2018

终于找到bug原因!记一下;还是不熟悉平台的原因造成的!

Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误? 而 直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错?

model.py,里面含有 ModelV 和 ModelP,另外还有 modelP.py 和 modelV.py 分别只含有 ModelP 和 ModeV 这两个对象,先使用 modelP.py 和 modelV.py 分别训练好模型,然后再在 model.py 里加载进来:

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

 def __init__(self):

  self.v1 = tf.Variable(66, name="v1")
  self.v2 = tf.Variable(77, name="v2")
  self.save_path = "model_v/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.saver = tf.train.Saver()
  self.sess = tf.Session()

 def train(self):
  self.sess.run(self.init)
  print 'v2', self.v2.eval(self.sess)

  self.saver.save(self.sess, self.save_path)
  print "ModelV saved."

 def predict(self):

  all_vars = tf.trainable_variables()
  for v in all_vars:
   print(v.name)
  self.saver.restore(self.sess, self.save_path)
  print "ModelV restored."
  print 'v2', self.v2.eval(self.sess)
  print '------------------------------------------------------------------'

class ModelP():

 def __init__(self):

  self.p1 = tf.Variable(88, name="p1")
  self.p2 = tf.Variable(99, name="p2")
  self.save_path = "model_p/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.saver = tf.train.Saver()
  self.sess = tf.Session()

 def train(self):
  self.sess.run(self.init)
  print 'p2', self.p2.eval(self.sess)

  self.saver.save(self.sess, self.save_path)
  print "ModelP saved."

 def predict(self):

  all_vars = tf.trainable_variables()
  for v in all_vars:
   print v.name
  self.saver.restore(self.sess, self.save_path)
  print "ModelP restored."
  print 'p2', self.p2.eval(self.sess)
  print '---------------------------------------------------------------------'


if __name__ == '__main__':
 v = ModelV()
 p = ModelP()
 v.predict()
 #v.train()
 p.predict() 
 #p.train()

这里 tf.global_variables_initializer() 很关键! 尽管你是分别在对象 ModelP 和 ModelV 内部分配和定义的 tf.Variable(),即 v1 v2 和 p1 p2,但是 对 tf 这个模块而言, 这些都是全局变量,可以通过以下代码查看所有的变量,你就会发现同一个文件中同时运行 ModelP 和 ModelV 在初始化之后都打印出了一样的变量,这个是问题的关键所在:

all_vars = tf.trainable_variables()
for v in all_vars:
 print(v.name)

错误。你可以交换 modelP 和 modelV 初始化的顺序,看看错误信息的变化

v1:0
v2:0
p1:0
p2:0
ModelV restored.
v2 77
v1:0
v2:0
p1:0
p2:0
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v2 not found in checkpoint
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v1 not found in checkpoint

实际上,分开运行时,模型保存的参数是正确的,因为在一个模型里的Variable就只有 v1 v2 或者 p1 p2; 但是在一个文件同时运行的时候,模型参数实际上保存的是 v1 v2 p1 p2四个,因为在默认情况下,创建的Saver,会直接保存所有的参数。而 Saver.restore() 又是默认(无Variable参数列表时)按照已经定义好的全局模型变量来加载对应的参数值, 在进行 ModelV.predict时,按照顺序(从debug可以看出,应该是按照参数顺序一次检测)在模型文件中查找相应的 key,此时能够找到对应的v1 v2,加载成功,但是在 ModelP.predict时,在model_p的模型文件中找不到 v1 和 v2,只有 p1 和 p2, 此时就会报错;不过这里的 第一次加载 还有 p1 p2 找不到没有报错,解释不通, 未完待续

Saver.save() 和 Saver.restore() 是一对, 分别只保存和加载模型的参数, 但是模型的结构怎么知道呢? 必须是你定义好了,而且要和保存的模型匹配才能加载;

如果想要在不定义模型的情况下直接加载出模型结构和模型参数值,使用

# 加载 结构,即 模型参数 变量等
new_saver = tf.train.import_meta_graph("model_v/model.ckpt.meta")
print "ModelV construct"
all_vars = tf.trainable_variables()
for v in all_vars:
 print v.name
 #print v.name,v.eval(self.sess) # v 都还未初始化,不能求值
# 加载模型 参数变量 的 值
new_saver.restore(self.sess, tf.train.latest_checkpoint('model_v/'))
print "ModelV restored."
all_vars = tf.trainable_variables()
for v in all_vars:
 print v.name,v.eval(self.sess)

加载 结构,即 模型参数 变量等完成后,就会有变量了,但是不能访问他的值,因为还未赋值,然后再restore一次即可得到值了

那么上述错误的解决方法就是这个改进版本的model.py;其实 tf.train.Saver 是可以带参数的,他可以保存你想要保存的模型参数,如果不带参数,很可能就会保存 tf.trainable_variables() 所有的variable,而 tf.trainable_variables()又是从 tf 全局得到的,因此只要在模型保存和加载时,构造对应的带参数的tf.train.Saver即可,这样就会保存和加载正确的模型了

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

 def __init__(self):

  self.v1 = tf.Variable(66, name="v1")
  self.v2 = tf.Variable(77, name="v2")
  self.save_path = "model_v/model.ckpt"
  self.init = tf.global_variables_initializer()

  self.sess = tf.Session()

 def train(self):
  saver = tf.train.Saver([self.v1, self.v2])
  self.sess.run(self.init)
  print 'v2', self.v2.eval(self.sess)

  saver.save(self.sess, self.save_path)
  print "ModelV saved."

 def predict(self):
  saver = tf.train.Saver([self.v1, self.v2])
  all_vars = tf.trainable_variables()
  for v in all_vars:
   print v.name

  v_vars = [v for v in all_vars if v.name == 'v1:0' or v.name == 'v2:0']
  print "ModelV restored."
  saver.restore(self.sess, self.save_path)
  for v in v_vars:
   print v.name,v.eval(self.sess) 
  print 'v2', self.v2.eval(self.sess)
  print '------------------------------------------------------------------'

class ModelP():

 def __init__(self):

  self.p1 = tf.Variable(88, name="p1")
  self.p2 = tf.Variable(99, name="p2")
  self.save_path = "model_p/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.sess = tf.Session()

 def train(self):
  saver = tf.train.Saver([self.p1, self.p2])
  self.sess.run(self.init)
  print 'p2', self.p2.eval(self.sess)

  saver.save(self.sess, self.save_path)
  print "ModelP saved."

 def predict(self):
  saver = tf.train.Saver([self.p1, self.p2])
  all_vars = tf.trainable_variables()
  p_vars = [v for v in all_vars if v.name == 'p1:0' or v.name == 'p2:0']
  for v in all_vars:
   print v.name
   #print v.name,v.eval(self.sess)
  saver.restore(self.sess, self.save_path)
  print "ModelP restored."
  for p in p_vars:
   print p.name,p.eval(self.sess)
  print 'p2', self.p2.eval(self.sess)
  print '----------------------------------------------------------'


if __name__ == '__main__':
 v = ModelV()
 p = ModelP()
 v.predict()
 #v.train()
 p.predict() 
 #p.train()

小结: 构造的Saver 最好带Variable参数,这样保证 保存和加载能够正确执行

以上这篇解决tensorflow模型参数保存和加载的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python&MongoDB爬取图书馆借阅记录
Feb 05 Python
解决python ogr shp字段写入中文乱码的问题
Dec 31 Python
Python使用lambda表达式对字典排序操作示例
Jul 25 Python
python tkinter实现屏保程序
Jul 30 Python
Python学习笔记之迭代器和生成器用法实例详解
Aug 08 Python
pygame实现贪吃蛇游戏(上)
Oct 29 Python
Python基础之字符串操作常用函数集合
Feb 09 Python
python实现查找所有程序的安装信息
Feb 18 Python
python应用Axes3D绘图(批量梯度下降算法)
Mar 25 Python
Tensorflow中k.gradients()和tf.stop_gradient()用法说明
Jun 10 Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 Python
Python使用struct处理二进制(pack和unpack用法)
Nov 12 Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
TensorFlow利用saver保存和提取参数的实例
Jul 26 #Python
You might like
$_GET['goods_id']+0 的使用详解
2013/06/06 PHP
php session_start()出错原因分析及解决方法
2013/10/28 PHP
thinkphp验证码显示不出来的解决方法
2014/03/29 PHP
PHP中使用Imagick实现各种图片效果实例
2015/01/21 PHP
详解PHP中instanceof关键字及instanceof关键字有什么作用
2015/11/05 PHP
PHP实现中国公民身份证号码有效性验证示例代码
2017/05/03 PHP
php处理静态页面:页面设置缓存时间实例
2017/06/22 PHP
PHP有序表查找之插值查找算法示例
2018/02/10 PHP
php设计模式之工厂方法模式分析【星际争霸游戏案例】
2020/01/23 PHP
深入理解node exports和module.exports区别
2016/06/01 Javascript
基于JavaScript代码实现自动生成表格
2016/06/15 Javascript
js 上传文件预览的简单实例
2016/08/16 Javascript
AngularJS 中的Promise --- $q服务详解
2016/09/14 Javascript
详解从新建vue项目到引入组件Element的方法
2017/08/29 Javascript
浅谈webpack SplitChunksPlugin实用指南
2018/09/17 Javascript
原生JS实现轮播图效果
2018/10/12 Javascript
layui数据表格跨行自动合并的例子
2019/09/02 Javascript
Python 进程之间共享数据(全局变量)的方法
2019/07/16 Python
双向RNN:bidirectional_dynamic_rnn()函数的使用详解
2020/01/20 Python
Python基于smtplib模块发送邮件代码实例
2020/05/29 Python
Python TestSuite生成测试报告过程解析
2020/07/23 Python
如何利用Python 进行边缘检测
2020/10/14 Python
html5+css3之CSS中的布局与Header的实现
2014/11/21 HTML / CSS
尤为Wconcept中国官网:韩国设计师品牌服饰
2019/01/10 全球购物
Fox Racing官方网站:越野摩托车和山地自行车装备和服装
2019/12/23 全球购物
俄罗斯品牌服装和鞋子在线商店:BRIONITY
2020/03/26 全球购物
介绍一下.net和Java的特点和区别
2012/09/26 面试题
思想汇报格式
2014/01/05 职场文书
优秀员工表扬信
2014/01/17 职场文书
党员违纪检讨书
2014/02/18 职场文书
企业文化标语口号
2014/06/09 职场文书
2015年元旦主持词开场白
2014/12/14 职场文书
中国合伙人观后感
2015/06/02 职场文书
小学生组织委员竞选稿
2015/11/21 职场文书
营销策划分析:怎么策划才能更好销量产品?
2019/09/04 职场文书
Python可变与不可变数据和深拷贝与浅拷贝
2022/04/06 Python