tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python检测远程端口是否打开的方法
Mar 14 Python
在Django中使用Sitemap的方法讲解
Jul 22 Python
Python实现自动为照片添加日期并分类的方法
Sep 30 Python
解决sublime+python3无法输出中文的问题
Dec 12 Python
Python的高阶函数用法实例分析
Apr 11 Python
Python 解决火狐浏览器不弹出下载框直接下载的问题
Mar 09 Python
pandas读取csv文件提示不存在的解决方法及原因分析
Apr 21 Python
Python爬虫逆向分析某云音乐加密参数的实例分析
Dec 04 Python
十个Python自动化常用操作,即拿即用
May 10 Python
只需要这一行代码就能让python计算速度提高十倍
May 24 Python
python 如何用terminal输入参数
May 25 Python
python百行代码实现汉服圈图片爬取
Nov 23 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
PHP 数字左侧自动补0
2008/03/31 PHP
PHP获取文件的MD5值并判断是否被修改的例子
2014/06/19 PHP
php的闭包(Closure)匿名函数初探
2016/02/14 PHP
php并发加锁示例
2016/10/17 PHP
不安全的常用的js写法
2009/09/15 Javascript
学习ExtJS accordion布局
2009/10/08 Javascript
js使用正则实现ReplaceAll全部替换的方法
2014/07/18 Javascript
JS 使用for循环遍历子节点查找元素
2014/09/06 Javascript
在WordPress中加入Google搜索功能的简单步骤讲解
2016/01/04 Javascript
Node.js程序中的本地文件操作用法小结
2016/03/06 Javascript
jQuery实现鼠标选文字发新浪微博的方法
2016/04/02 Javascript
vue axios用法教程详解
2017/07/23 Javascript
Angular指令之restict匹配模式的详解
2017/07/27 Javascript
JS设计模式之策略模式概念与用法分析
2018/02/05 Javascript
vue 实现复制内容到粘贴板clipboard的方法
2018/03/17 Javascript
使用koa2创建web项目的方法步骤
2019/03/12 Javascript
微信小程序封装分享与分销功能过程解析
2019/08/13 Javascript
原生JS实现顶部导航栏显示按钮+搜索框功能
2019/12/25 Javascript
Element PageHeader页头的使用方法
2020/07/26 Javascript
ant-design-vue中tree增删改的操作方法
2020/11/03 Javascript
如何在现代JavaScript中编写异步任务
2021/01/31 Javascript
[32:47]完美世界DOTA2联赛 GXR vs IO 第二场 11.07
2020/11/09 DOTA
Python中的赋值、浅拷贝、深拷贝介绍
2015/03/09 Python
深入理解Python中命名空间的查找规则LEGB
2015/08/06 Python
Python中在脚本中引用其他文件函数的实现方法
2016/06/23 Python
Python中的集合介绍
2019/01/28 Python
python对一个数向上取整的实例方法
2020/06/18 Python
python等待10秒执行下一命令的方法
2020/07/19 Python
什么是跨站脚本攻击
2014/12/11 面试题
飞利信loadrunner和软件测试笔试题
2012/09/22 面试题
八一建军节部队活动方案
2014/02/04 职场文书
绿色环保演讲稿
2014/05/10 职场文书
机械工程师岗位职责
2014/06/16 职场文书
2014年教务处工作总结
2014/12/03 职场文书
女方家长婚礼答谢词
2015/09/29 职场文书
python基于tkinter实现gif录屏功能
2021/05/19 Python