tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python监控主机是否存活并以邮件报警
Sep 22 Python
常见python正则用法的简单实例
Jun 21 Python
Python爬取当当、京东、亚马逊图书信息代码实例
Dec 09 Python
利用Python如何实现数据驱动的接口自动化测试
May 11 Python
python 3.6.5 安装配置方法图文教程
Sep 18 Python
pytorch构建多模型实例
Jan 15 Python
基于spring boot 日志(logback)报错的解决方式
Feb 20 Python
jupyter notebook中美观显示矩阵实例
Apr 17 Python
对python pandas中 inplace 参数的理解
Jun 27 Python
一行代码python实现文件共享服务器
Apr 22 Python
用Python实现一个打字速度测试工具来测试你的手速
May 28 Python
python plt.plot bar 如何设置绘图尺寸大小
Jun 01 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
自己做矿石收音机
2021/03/02 无线电
PHP自动重命名文件实现方法
2014/11/04 PHP
php实现scws中文分词搜索的方法
2015/12/25 PHP
php+mysql查询实现无限下级分类树输出示例
2016/10/03 PHP
php实现批量上传数据到数据库(.csv格式)的案例
2017/06/18 PHP
仿微博字符限制效果实现代码
2012/04/20 Javascript
判断多个input type=file是否有已经选择好文件的代码
2012/05/23 Javascript
Js如何判断客户端是PC还是手持设备简单分析
2012/11/22 Javascript
JQuery的$命名冲突详细解析
2013/12/28 Javascript
深入理解JavaScript系列(27):设计模式之建造者模式详解
2015/03/03 Javascript
JS实现FLASH幻灯片图片切换效果的方法
2015/03/04 Javascript
JS作为值的函数用法示例
2016/06/20 Javascript
基于JS模仿windows文件按名称排序效果
2016/06/29 Javascript
利用Javascript实现BMI计算器
2016/08/16 Javascript
js+div+css下拉导航菜单完整代码分享
2016/12/28 Javascript
JavaScript实现为事件句柄绑定监听函数的方法分析
2017/11/14 Javascript
浅谈ng-zorro使用心得
2018/12/03 Javascript
ionic4+angular7+cordova上传图片功能的实例代码
2019/06/19 Javascript
antd-DatePicker组件获取时间值,及相关设置方式
2020/10/27 Javascript
Vue 实现拨打电话操作
2020/11/16 Javascript
Python加pyGame实现的简单拼图游戏实例
2015/05/15 Python
python进程管理工具supervisor的安装与使用教程
2017/09/05 Python
深入理解Python中的super()方法
2017/11/20 Python
python 获取页面表格数据存放到csv中的方法
2018/12/26 Python
Pytorch Tensor基本数学运算详解
2019/12/30 Python
基于python实现百度语音识别和图灵对话
2020/11/02 Python
Python中lru_cache的使用和实现详解
2021/01/25 Python
英国时尚女装购物网站:Missguided
2018/08/23 全球购物
机械个人求职信范文
2014/01/24 职场文书
期末自我鉴定
2014/02/02 职场文书
售后服务经理岗位职责
2014/02/25 职场文书
竞选班干部演讲稿600字
2014/08/20 职场文书
立案决定书范文
2015/06/24 职场文书
python实现腾讯滑块验证码识别
2021/04/27 Python
Python函数中的不定长参数相关知识总结
2021/06/24 Python
浅谈JavaScript作用域
2021/12/06 Javascript