tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python程序中操作文件之flush()方法的使用教程
May 24 Python
centos6.7安装python2.7.11的具体方法
Jan 16 Python
Python django实现简单的邮件系统发送邮件功能
Jul 14 Python
Python 使用with上下文实现计时功能
Mar 09 Python
简单了解python列表和元组的区别
May 14 Python
Python 判断时间是否在时间区间内的实例
May 16 Python
Python替换NumPy数组中大于某个值的所有元素实例
Jun 08 Python
python 使用paramiko模块进行封装,远程操作linux主机的示例代码
Dec 03 Python
Python 实现RSA加解密文本文件
Dec 30 Python
python 实现图片裁剪小工具
Feb 02 Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 Python
Python中使用Lambda函数的5种用法
Apr 01 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
PHP 只允许指定IP访问(允许*号通配符过滤IP)
2014/07/08 PHP
phpstorm 正则匹配删除空行、注释行(替换注释行为空行)
2018/01/21 PHP
基于jQuery的简单的列表导航菜单
2011/03/02 Javascript
通过JS自动隐藏手机浏览器的地址栏实现原理与代码
2013/01/02 Javascript
Javascript代码在页面加载时的执行顺序介绍
2013/05/03 Javascript
jquery点击页面任何区域实现鼠标焦点十字效果
2013/06/21 Javascript
js浮点数精确计算(加、减、乘、除)
2013/12/26 Javascript
原生javascript实现的分页插件pagenav
2014/08/28 Javascript
jQuery实现的登录浮动框效果代码
2015/09/26 Javascript
Javascript同时声明一连串(多个)变量的方法
2017/01/23 Javascript
Angular搜索 过滤 批量删除 添加 表单验证功能集锦(实例代码)
2017/10/25 Javascript
JavaScript实现二叉树的先序、中序及后序遍历方法详解
2017/10/26 Javascript
vue中promise的使用及异步请求数据的方法
2018/11/08 Javascript
vue开发环境配置跨域的方法步骤
2019/01/16 Javascript
微信小程序实现手势滑动效果
2019/08/26 Javascript
[01:32]DOTA2次级联赛——首支职业女子战队选拔赛全记录
2014/10/23 DOTA
python 输出一个两行字符的变量
2009/02/05 Python
浅谈python多线程和队列管理shell程序
2015/08/04 Python
Python的Flask框架标配模板引擎Jinja2的使用教程
2016/07/12 Python
python下10个简单实例代码
2017/11/15 Python
Django框架会话技术实例分析【Cookie与Session】
2019/05/24 Python
详解pandas使用drop_duplicates去除DataFrame重复项参数
2019/08/01 Python
Python判断字符串是否为空和null方法实例
2020/04/26 Python
Python新手学习标准库模块命名
2020/05/29 Python
Python如何给你的程序做性能测试
2020/07/29 Python
英国评分最高的女性剃须刀订阅盒:FFS Beauty
2018/01/25 全球购物
Coggles美国/加拿大:高级国际时装零售商
2018/10/23 全球购物
采购主管工作职责
2013/12/12 职场文书
孝老爱亲模范事迹
2014/01/24 职场文书
大班开学家长寄语
2014/04/04 职场文书
毕业生工作求职信
2014/06/30 职场文书
大学生找工作求职信
2014/07/09 职场文书
民政工作个人总结
2015/02/28 职场文书
青年干部培训班学习心得体会
2016/01/06 职场文书
十一月早安语录:把心放轻,人生就是一朵自在的云
2019/11/04 职场文书
2007年老电脑安装win11会怎么样? 网友实测win11在老电脑运行良好
2021/11/21 数码科技