tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python定时器使用示例分享
Feb 16 Python
python基础教程之常用运算符
Aug 29 Python
Python iter()函数用法实例分析
Mar 17 Python
Python操作Oracle数据库的简单方法和封装类实例
May 07 Python
python爬虫之urllib,伪装,超时设置,异常处理的方法
Dec 19 Python
python实现一个简单的udp通信的示例代码
Feb 01 Python
详解python中TCP协议中的粘包问题
Mar 22 Python
python如何实现从视频中提取每秒图片
Oct 22 Python
深入了解python中元类的相关知识
Aug 29 Python
Python实现隐马尔可夫模型的前向后向算法的示例代码
Dec 31 Python
使用python接受tgam的脑波数据实例
Apr 09 Python
Python爬虫之爬取最新更新的小说网站
May 06 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
上海无线电三厂简史修改版
2021/03/01 无线电
在命令行下运行PHP脚本[带参数]的方法
2010/01/22 PHP
PHP has encountered a Stack overflow问题解决方法
2014/11/03 PHP
php调用mysql存储过程实例分析
2014/12/29 PHP
php数组合并与拆分实例分析
2015/06/12 PHP
高质量PHP代码的50个实用技巧必备(下)
2016/01/22 PHP
Laravel 5.4向IoC容器中添加自定义类的方法示例
2017/08/15 PHP
PHP实现的文件浏览器功能简单示例
2019/09/12 PHP
一个简单的JavaScript 日期计算算法
2009/09/11 Javascript
js实现页面跳转重定向的几种方式
2014/05/29 Javascript
分享28款免费实用的 JQuery 图片和内容滑块插件
2014/12/15 Javascript
JavaScript中5种调用函数的方法
2015/03/12 Javascript
js实现Form栏显示全格式时间时钟效果代码
2015/08/19 Javascript
JQuery异步加载PartialView的方法
2016/06/07 Javascript
NodeJS使用formidable实现文件上传
2016/10/27 NodeJs
微信小程序之仿微信漂流瓶实例
2016/12/09 Javascript
JavaScript requestAnimationFrame动画详解
2017/09/14 Javascript
JavaScript设计模式之责任链模式实例分析
2019/01/16 Javascript
Vue.js组件使用props传递数据的方法
2019/10/19 Javascript
electron+vue实现div contenteditable截图功能
2020/01/07 Javascript
javascript实现获取中文汉字拼音首字母
2020/05/19 Javascript
[57:31]DOTA2-DPC中国联赛 正赛 SAG vs CDEC BO3 第一场 2月1日
2021/03/11 DOTA
进一步探究Python中的正则表达式
2015/04/28 Python
Python实现的根据IP地址计算子网掩码位数功能示例
2018/05/23 Python
Python使用sort和class实现的多级排序功能示例
2018/08/15 Python
用python写一个带有gui界面的密码生成器
2020/11/06 Python
python-地图可视化组件folium的操作
2020/12/14 Python
CSS3 实现雷达扫描图的示例代码
2020/09/21 HTML / CSS
英国排名第一的最新设计师品牌手表独立零售商:TIC Watches
2016/09/24 全球购物
Ben Sherman官方网站:英国男装品牌
2019/10/22 全球购物
索引覆盖(Index Covering)查询含义
2012/02/18 面试题
党员干部2014全国两会学习心得体会
2014/03/10 职场文书
晨会主持词
2014/03/17 职场文书
小学生差生评语
2014/12/29 职场文书
Python极值整数的边界探讨分析
2021/09/15 Python
Python中使用tkFileDialog实现文件选择、保存和路径选择
2022/05/20 Python