tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
让 python 命令行也可以自动补全
Nov 30 Python
Python中返回字典键的值的values()方法使用
May 22 Python
python 与GO中操作slice,list的方式实例代码
Mar 20 Python
浅谈python中拼接路径os.path.join斜杠的问题
Oct 23 Python
对python中数据集划分函数StratifiedShuffleSplit的使用详解
Dec 11 Python
pandas的to_datetime时间转换使用及学习心得
Aug 11 Python
python3实现的zip格式压缩文件夹操作示例
Aug 17 Python
python实现大量图片重命名
Mar 23 Python
Python实现socket非阻塞通讯功能示例
Nov 06 Python
python 使用多线程创建一个Buffer缓存器的实现思路
Jul 02 Python
Python3 用什么IDE开发工具比较好
Nov 28 Python
python爬虫scrapy框架之增量式爬虫的示例代码
Feb 26 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
php数组总结篇(一)
2008/09/30 PHP
PHP判断IP并转跳到相应城市分站的方法
2015/03/25 PHP
php实现页面纯静态的实例代码
2017/06/21 PHP
关于PHP中协程和阻塞的一些理解与思考
2017/08/11 PHP
再谈querySelector和querySelectorAll的区别与联系
2012/04/20 Javascript
再探JavaScript作用域
2014/09/24 Javascript
js对象基础实例分析
2015/01/13 Javascript
JavaScript每天必学之事件
2016/09/18 Javascript
详解Nodejs基于mongoose模块的增删改查的操作
2016/12/21 NodeJs
使用contextMenu插件实现Bootstrap table弹出右键菜单
2017/02/20 Javascript
jQuery模拟窗口抖动效果
2017/03/15 Javascript
JavaScript实现文件下载并重命名代码实例
2019/12/12 Javascript
JS几个常用的函数和对象定义与用法示例
2020/01/15 Javascript
[04:10]2016国际邀请赛中国区预选赛第二日TOP10精彩集锦
2016/06/28 DOTA
[01:01]青春无憾,一战成名——DOTA2全国高校联赛开启
2018/02/25 DOTA
用Python计算三角函数之atan()方法的使用
2015/05/15 Python
Python中if __name__ == '__main__'作用解析
2015/06/29 Python
Python网络编程之TCP与UDP协议套接字用法示例
2018/02/02 Python
Python实现自动上京东抢手机
2018/02/06 Python
python验证码识别教程之利用投影法、连通域法分割图片
2018/06/04 Python
Python 实现「食行生鲜」签到领积分功能
2018/09/26 Python
Python中的正则表达式与JSON数据交换格式
2019/07/03 Python
wxPython色环电阻计算器
2019/11/18 Python
pytorch中的自定义数据处理详解
2020/01/06 Python
Python flask框架如何显示图像到web页面
2020/06/03 Python
利用canvas实现图片压缩的示例代码
2018/07/17 HTML / CSS
中国医药集团国药在线:国药网
2017/02/06 全球购物
澳大利亚便宜隐形眼镜购买网站:QUICKLENS Australia
2018/10/06 全球购物
群众路线教育实践活动的心得体会
2014/09/03 职场文书
2014年机关党建工作总结
2014/11/11 职场文书
党员评议自我评价
2015/03/03 职场文书
2015年幼师工作总结
2015/04/28 职场文书
倡议书的格式写法
2015/04/28 职场文书
创业计划书之蛋糕店
2019/08/29 职场文书
浅谈Python类的单继承相关知识
2021/05/12 Python
关于vue-router-link选择样式设置
2022/04/30 Vue.js