解决tensorflow模型参数保存和加载的问题


Posted in Python onJuly 26, 2018

终于找到bug原因!记一下;还是不熟悉平台的原因造成的!

Q:为什么会出现两个模型对象在同一个文件中一起运行,当直接读取他们分开运行时训练出来的模型会出错,而且总是有一个正确,一个读取错误? 而 直接在同一个文件又训练又重新加载模型预测不出错,而且更诡异的是此时用分文件里的对象加载模型不会出错?

model.py,里面含有 ModelV 和 ModelP,另外还有 modelP.py 和 modelV.py 分别只含有 ModelP 和 ModeV 这两个对象,先使用 modelP.py 和 modelV.py 分别训练好模型,然后再在 model.py 里加载进来:

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

 def __init__(self):

  self.v1 = tf.Variable(66, name="v1")
  self.v2 = tf.Variable(77, name="v2")
  self.save_path = "model_v/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.saver = tf.train.Saver()
  self.sess = tf.Session()

 def train(self):
  self.sess.run(self.init)
  print 'v2', self.v2.eval(self.sess)

  self.saver.save(self.sess, self.save_path)
  print "ModelV saved."

 def predict(self):

  all_vars = tf.trainable_variables()
  for v in all_vars:
   print(v.name)
  self.saver.restore(self.sess, self.save_path)
  print "ModelV restored."
  print 'v2', self.v2.eval(self.sess)
  print '------------------------------------------------------------------'

class ModelP():

 def __init__(self):

  self.p1 = tf.Variable(88, name="p1")
  self.p2 = tf.Variable(99, name="p2")
  self.save_path = "model_p/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.saver = tf.train.Saver()
  self.sess = tf.Session()

 def train(self):
  self.sess.run(self.init)
  print 'p2', self.p2.eval(self.sess)

  self.saver.save(self.sess, self.save_path)
  print "ModelP saved."

 def predict(self):

  all_vars = tf.trainable_variables()
  for v in all_vars:
   print v.name
  self.saver.restore(self.sess, self.save_path)
  print "ModelP restored."
  print 'p2', self.p2.eval(self.sess)
  print '---------------------------------------------------------------------'


if __name__ == '__main__':
 v = ModelV()
 p = ModelP()
 v.predict()
 #v.train()
 p.predict() 
 #p.train()

这里 tf.global_variables_initializer() 很关键! 尽管你是分别在对象 ModelP 和 ModelV 内部分配和定义的 tf.Variable(),即 v1 v2 和 p1 p2,但是 对 tf 这个模块而言, 这些都是全局变量,可以通过以下代码查看所有的变量,你就会发现同一个文件中同时运行 ModelP 和 ModelV 在初始化之后都打印出了一样的变量,这个是问题的关键所在:

all_vars = tf.trainable_variables()
for v in all_vars:
 print(v.name)

错误。你可以交换 modelP 和 modelV 初始化的顺序,看看错误信息的变化

v1:0
v2:0
p1:0
p2:0
ModelV restored.
v2 77
v1:0
v2:0
p1:0
p2:0
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v2 not found in checkpoint
W tensorflow/core/framework/op_kernel.cc:975] Not found: Key v1 not found in checkpoint

实际上,分开运行时,模型保存的参数是正确的,因为在一个模型里的Variable就只有 v1 v2 或者 p1 p2; 但是在一个文件同时运行的时候,模型参数实际上保存的是 v1 v2 p1 p2四个,因为在默认情况下,创建的Saver,会直接保存所有的参数。而 Saver.restore() 又是默认(无Variable参数列表时)按照已经定义好的全局模型变量来加载对应的参数值, 在进行 ModelV.predict时,按照顺序(从debug可以看出,应该是按照参数顺序一次检测)在模型文件中查找相应的 key,此时能够找到对应的v1 v2,加载成功,但是在 ModelP.predict时,在model_p的模型文件中找不到 v1 和 v2,只有 p1 和 p2, 此时就会报错;不过这里的 第一次加载 还有 p1 p2 找不到没有报错,解释不通, 未完待续

Saver.save() 和 Saver.restore() 是一对, 分别只保存和加载模型的参数, 但是模型的结构怎么知道呢? 必须是你定义好了,而且要和保存的模型匹配才能加载;

如果想要在不定义模型的情况下直接加载出模型结构和模型参数值,使用

# 加载 结构,即 模型参数 变量等
new_saver = tf.train.import_meta_graph("model_v/model.ckpt.meta")
print "ModelV construct"
all_vars = tf.trainable_variables()
for v in all_vars:
 print v.name
 #print v.name,v.eval(self.sess) # v 都还未初始化,不能求值
# 加载模型 参数变量 的 值
new_saver.restore(self.sess, tf.train.latest_checkpoint('model_v/'))
print "ModelV restored."
all_vars = tf.trainable_variables()
for v in all_vars:
 print v.name,v.eval(self.sess)

加载 结构,即 模型参数 变量等完成后,就会有变量了,但是不能访问他的值,因为还未赋值,然后再restore一次即可得到值了

那么上述错误的解决方法就是这个改进版本的model.py;其实 tf.train.Saver 是可以带参数的,他可以保存你想要保存的模型参数,如果不带参数,很可能就会保存 tf.trainable_variables() 所有的variable,而 tf.trainable_variables()又是从 tf 全局得到的,因此只要在模型保存和加载时,构造对应的带参数的tf.train.Saver即可,这样就会保存和加载正确的模型了

# -*- coding: utf8 -*-

import tensorflow as tf

class ModelV():

 def __init__(self):

  self.v1 = tf.Variable(66, name="v1")
  self.v2 = tf.Variable(77, name="v2")
  self.save_path = "model_v/model.ckpt"
  self.init = tf.global_variables_initializer()

  self.sess = tf.Session()

 def train(self):
  saver = tf.train.Saver([self.v1, self.v2])
  self.sess.run(self.init)
  print 'v2', self.v2.eval(self.sess)

  saver.save(self.sess, self.save_path)
  print "ModelV saved."

 def predict(self):
  saver = tf.train.Saver([self.v1, self.v2])
  all_vars = tf.trainable_variables()
  for v in all_vars:
   print v.name

  v_vars = [v for v in all_vars if v.name == 'v1:0' or v.name == 'v2:0']
  print "ModelV restored."
  saver.restore(self.sess, self.save_path)
  for v in v_vars:
   print v.name,v.eval(self.sess) 
  print 'v2', self.v2.eval(self.sess)
  print '------------------------------------------------------------------'

class ModelP():

 def __init__(self):

  self.p1 = tf.Variable(88, name="p1")
  self.p2 = tf.Variable(99, name="p2")
  self.save_path = "model_p/model.ckpt"
  self.init = tf.global_variables_initializer()
  self.sess = tf.Session()

 def train(self):
  saver = tf.train.Saver([self.p1, self.p2])
  self.sess.run(self.init)
  print 'p2', self.p2.eval(self.sess)

  saver.save(self.sess, self.save_path)
  print "ModelP saved."

 def predict(self):
  saver = tf.train.Saver([self.p1, self.p2])
  all_vars = tf.trainable_variables()
  p_vars = [v for v in all_vars if v.name == 'p1:0' or v.name == 'p2:0']
  for v in all_vars:
   print v.name
   #print v.name,v.eval(self.sess)
  saver.restore(self.sess, self.save_path)
  print "ModelP restored."
  for p in p_vars:
   print p.name,p.eval(self.sess)
  print 'p2', self.p2.eval(self.sess)
  print '----------------------------------------------------------'


if __name__ == '__main__':
 v = ModelV()
 p = ModelP()
 v.predict()
 #v.train()
 p.predict() 
 #p.train()

小结: 构造的Saver 最好带Variable参数,这样保证 保存和加载能够正确执行

以上这篇解决tensorflow模型参数保存和加载的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中type的构造函数参数含义说明
Jun 21 Python
python如何在终端里面显示一张图片
Aug 17 Python
Python实现字符串与数组相互转换功能示例
Sep 22 Python
python pandas 对时间序列文件处理的实例
Jun 22 Python
符合语言习惯的 Python 优雅编程技巧【推荐】
Sep 25 Python
Python中turtle库的使用实例
Sep 09 Python
python将四元数变换为旋转矩阵的实例
Dec 04 Python
python numpy数组中的复制知识解析
Feb 03 Python
使用Python中tkinter库简单gui界面制作及打包成exe的操作方法(二)
Oct 12 Python
Python操控mysql批量插入数据的实现方法
Oct 27 Python
Lombok插件安装(IDEA)及配置jar包使用详解
Nov 04 Python
详解python网络进程
Jun 15 Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
TensorFlow利用saver保存和提取参数的实例
Jul 26 #Python
You might like
PHP图片验证码制作实现分享(全)
2012/05/10 PHP
基于PHP静态类的原罪详解
2013/05/06 PHP
PHP实现删除字符串中任何字符的函数
2015/08/11 PHP
php通过curl添加cookie伪造登陆抓取数据的方法
2016/04/02 PHP
PHP生成静态HTML文档实现代码
2016/06/23 PHP
PHP大文件分割上传 PHP分片上传
2017/08/28 PHP
Java 正则表达式学习总结和一些小例子
2012/09/13 Javascript
禁用页面部分JavaScript方法的具体实现
2013/07/31 Javascript
js实现Select头像选择实时预览代码
2015/08/17 Javascript
每天一篇javascript学习小结(String对象)
2015/11/18 Javascript
JAVA Web实时消息后台服务器推送技术---GoEasy
2016/11/04 Javascript
Vue.js教程之计算属性
2016/11/11 Javascript
关于JS与jQuery中的文档加载问题
2017/08/22 jQuery
解决vue打包之后静态资源图片失效的问题
2018/02/21 Javascript
element-ui 表格实现单元格可编辑的示例
2018/02/26 Javascript
Nodejs实现爬虫抓取数据实例解析
2018/07/05 NodeJs
vue实现动态显示与隐藏底部导航的方法分析
2019/02/11 Javascript
react PropTypes校验传递的值操作示例
2020/04/28 Javascript
Vue 3自定义指令开发的相关总结
2021/01/29 Vue.js
Python和Ruby中each循环引用变量问题(一个隐秘BUG?)
2014/06/04 Python
Python实现pdf文档转txt的方法示例
2018/01/19 Python
python 对象和json互相转换方法
2018/03/22 Python
python使用tkinter库实现五子棋游戏
2019/06/18 Python
Django --Xadmin 判断登录者身份实例
2020/07/03 Python
德国旅行、体验和活动的预订平台:Watado
2019/12/04 全球购物
意大利消费电子产品购物网站:SLG Store
2019/12/26 全球购物
如何开发安全的AJAX应用
2014/03/26 面试题
外贸公司实习自我鉴定
2013/09/24 职场文书
国际商务专业学生个人的自我评价
2013/09/28 职场文书
个人求职简历的自我评价
2013/10/19 职场文书
工厂实习感言
2014/01/14 职场文书
社区义诊活动总结
2014/04/30 职场文书
庆祝教师节标语
2014/10/09 职场文书
学校政风行风自查自纠报告
2014/10/21 职场文书
Element-ui Layout布局(Row和Col组件)的实现
2021/12/06 Vue.js
maven依赖的version声明控制方式
2022/01/18 Java/Android