tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 判断自定义对象类型
Mar 21 Python
python 七种邮件内容发送方法实例
Apr 22 Python
wxPython窗口的继承机制实例分析
Sep 28 Python
对python使用http、https代理的实例讲解
May 07 Python
Python读写及备份oracle数据库操作示例
May 17 Python
python3基于TCP实现CS架构文件传输
Jul 28 Python
Python 输出时去掉列表元组外面的方括号与圆括号的方法
Dec 24 Python
Python基本数据结构与用法详解【列表、元组、集合、字典】
Mar 23 Python
python自动分箱,计算woe,iv的实例代码
Nov 22 Python
python3 动态模块导入与全局变量使用实例
Dec 22 Python
python利用后缀表达式实现计算器功能
Feb 22 Python
Django框架中视图的用法
Jun 10 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
php email邮箱正则
2008/10/08 PHP
php cookie的操作实现代码(登录)
2010/12/29 PHP
如何解决phpmyadmin导入数据库文件最大限制2048KB
2015/10/09 PHP
PHP提取字符串中的手机号正则表达式怎么写
2017/07/17 PHP
锋利的jQuery 要点归纳(三) jQuery中的事件和动画(下:动画篇)
2010/03/24 Javascript
JavaScript 面向对象的之私有成员和公开成员
2010/05/04 Javascript
jQuery UI的Dialog无法提交问题的解决方法
2011/01/11 Javascript
js将json格式内容转换成对象的方法
2013/11/01 Javascript
javascript和jquery实现设置和移除文本框默认值效果代码
2015/01/13 Javascript
浅谈使用MVC模式进行JavaScript程序开发
2015/11/10 Javascript
jQuery插件Easyui设置datagrid的pageNumber导致两次请求问题的解决方法
2016/08/06 Javascript
原生js编写基于面向对象的分页组件
2016/12/05 Javascript
JS完成画圆圈的小球
2017/03/07 Javascript
javascript面向对象创建对象的方式小结
2019/07/29 Javascript
Node.js创建一个Express服务的方法详解
2020/01/06 Javascript
vue $mount 和 el的区别说明
2020/09/11 Javascript
[04:47]DOTA2-潍坊风行电子俱乐部探秘
2014/08/08 DOTA
[43:18]NB vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.22
2019/09/05 DOTA
[50:02]完美世界DOTA2联赛循环赛 Magma vs IO BO2第一场 11.01
2020/11/02 DOTA
对于Python的框架中一些会话程序的管理
2015/04/20 Python
Python 如何访问外围作用域中的变量
2016/09/11 Python
python实现Adapter模式实例代码
2018/02/09 Python
pandas 小数位数 精度的处理方法
2018/06/09 Python
Python argparse模块使用方法解析
2020/02/20 Python
CSS3轻松实现清新 Loading 效果的简单实例
2016/06/06 HTML / CSS
使用css3实现超炫的loading加载动画效果
2014/05/07 HTML / CSS
html5本地存储_动力节点Java学院整理
2017/07/12 HTML / CSS
香港中原电器网上商店:Chung Yuen
2019/06/26 全球购物
德国在线香料制造商:Gewürzland
2020/03/10 全球购物
财务会计专业求职信范文
2013/12/31 职场文书
幼儿园六一儿童节活动方案
2014/08/26 职场文书
立志成才演讲稿
2014/09/04 职场文书
房屋鉴定委托书范本
2014/09/23 职场文书
给客户的检讨书
2014/12/21 职场文书
催款通知书范文
2015/04/17 职场文书
使用react+redux实现计数器功能及遇到问题
2021/06/02 Javascript