tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python if not in 多条件判断代码
Sep 21 Python
python 数据的清理行为实例详解
Jul 12 Python
python实现远程通过网络邮件控制计算机重启或关机
Feb 22 Python
Python3.6实现连接mysql或mariadb的方法分析
May 18 Python
python pillow模块使用方法详解
Aug 30 Python
使用python远程操作linux过程解析
Dec 04 Python
解决python3插入mysql时内容带有引号的问题
Mar 02 Python
python虚拟环境模块venv使用及示例
Mar 04 Python
xadmin使用formfield_for_dbfield函数过滤下拉表单实例
Apr 07 Python
Python高并发和多线程有什么关系
Nov 14 Python
python实现简单倒计时功能
Apr 21 Python
Django + Taro 前后端分离项目实现企业微信登录功能
Apr 07 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
PHP实现克鲁斯卡尔算法实例解析
2014/08/22 PHP
php实现parent调用父类的构造方法与被覆写的方法
2015/02/11 PHP
PHP操作mysql数据库分表的方法
2016/06/09 PHP
PHP实现的观察者模式实例
2017/06/21 PHP
实例介绍PHP中zip_open()函数用法
2019/02/15 PHP
javascript第一课
2007/02/27 Javascript
js实现权限树的更新权限时的全选全消功能
2009/02/17 Javascript
JavaScript的继承的封装介绍
2013/10/15 Javascript
jQuery实现数字加减效果汇总
2014/12/16 Javascript
12行javascript代码绘制一个八卦图
2015/04/02 Javascript
JavaScript+html5 canvas实现本地截图教程
2020/04/16 Javascript
jQuery插件EasyUI实现Layout框架页面中弹出窗体到最顶层效果(穿越iframe)
2016/08/05 Javascript
JS简单实现浮动窗口效果示例
2016/09/07 Javascript
简单的渐变轮播插件
2017/01/12 Javascript
js 单引号替换成双引号,双引号替换成单引号的实现方法
2017/02/16 Javascript
vuejs开发组件分享之H5图片上传、压缩及拍照旋转的问题处理
2017/03/06 Javascript
JS判断非空至少输入两个字符的简单实现方法
2017/06/23 Javascript
JavaScript 中使用 Generator的方法
2017/12/29 Javascript
浅谈Vue内置component组件的应用场景
2018/03/27 Javascript
Vuex的基本概念、项目搭建以及入坑点
2018/11/04 Javascript
微信小程序实现watch监听
2020/06/04 Javascript
[02:14]完美“圣”典2016风云人物:xiao8专访
2016/12/01 DOTA
在Python下尝试多线程编程
2015/04/28 Python
python中黄金分割法实现方法
2015/05/06 Python
python类中super()和__init__()的区别
2016/10/18 Python
Win10下python 2.7.13 安装配置方法图文教程
2018/09/18 Python
virtualenv 指定 python 解释器的版本方法
2018/10/25 Python
解决Python正则表达式匹配反斜杠''\''问题
2019/07/17 Python
使用python实现kNN分类算法
2019/10/16 Python
在 Linux/Mac 下为Python函数添加超时时间的方法
2020/02/20 Python
使用Python合成图片的实现代码(图片添加个性化文本,图片上叠加其他图片)
2020/04/30 Python
Boom手表官网:瑞典手表品牌,设计你的手表
2019/03/11 全球购物
党的群众路线教育实践活动学习计划
2014/11/03 职场文书
2014年个人思想工作总结
2014/11/27 职场文书
2014预防青少年违法犯罪工作总结
2014/12/10 职场文书
可怜妈妈观后感
2015/06/09 职场文书