tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的错误处理
Apr 10 Python
python二分查找算法的递归实现方法
May 12 Python
Python设计实现的计算器功能完整实例
Aug 18 Python
Python中的默认参数实例分析
Jan 29 Python
Python3爬虫使用Fidder实现APP爬取示例
Nov 27 Python
树莓派用python中的OpenCV输出USB摄像头画面
Jun 22 Python
Python自定义聚合函数merge与transform区别详解
May 26 Python
Django CBV模型源码运行流程详解
Aug 17 Python
Python内置函数property()如何使用
Sep 01 Python
python 自动刷新网页的两种方法
Apr 20 Python
利用Selenium添加cookie实现自动登录的示例代码(fofa)
May 08 Python
pytorch--之halfTensor的使用详解
May 24 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
利用php递归实现无限分类 格式化数组的详解
2013/06/08 PHP
PHP中怎样保持SESSION不过期 原理及方案介绍
2013/08/08 PHP
PHP curl 抓取AJAX异步内容示例
2014/09/09 PHP
php的mssql数据库连接类实例
2014/11/28 PHP
PHP常用技巧汇总
2016/03/04 PHP
YII2 实现多语言配置的方法分享
2017/01/11 PHP
PHP+ajax实现获取新闻数据简单示例
2018/05/08 PHP
Yii2.0框架模型添加/修改/删除数据操作示例
2019/07/18 PHP
jquery提升性能最佳实践小结
2010/12/06 Javascript
Javascript 面试题随笔
2011/03/31 Javascript
也说JavaScript中String类的replace函数
2011/09/22 Javascript
jquery获取radio值(单选组radio)
2014/10/16 Javascript
AngularJS语法详解
2015/01/23 Javascript
NODE.JS跨域问题的完美解决方案
2016/10/20 Javascript
详解JavaScript树结构
2017/01/09 Javascript
js实现音乐播放控制条
2017/09/09 Javascript
jQuery幻灯片插件owlcarousel参数说明中文文档
2018/02/27 jQuery
微信小程序常用简易小函数总结
2019/02/01 Javascript
Echarts实现单条折线可拖拽效果
2019/12/19 Javascript
前端开发之便利店收银系统代码
2019/12/27 Javascript
如何在VUE中使用vue-awesome-swiper
2021/01/04 Vue.js
[02:42]DOTA2城市挑战赛收官在即 四强之争风起云涌
2018/06/05 DOTA
python实现socket端口重定向示例
2014/02/10 Python
Python图像处理之简单画板实现方法示例
2018/08/30 Python
opencv3/C++ 平面对象识别&透视变换方式
2019/12/11 Python
alice McCALL官网:澳大利亚时尚品牌
2020/11/16 全球购物
Quiksilver美国官网:始于1969年的优质冲浪服和滑雪板外套
2020/04/20 全球购物
代码中finally中的代码会不会执行
2012/02/06 面试题
我的大学生活演讲稿
2014/04/25 职场文书
暑期学习心得体会
2014/09/02 职场文书
实习生工作证明范本
2014/09/14 职场文书
应届生简历自我评价
2015/03/11 职场文书
叶问观后感
2015/06/15 职场文书
2015大学生入党个人自传
2015/06/26 职场文书
七一慰问简报
2015/07/20 职场文书
2016入党积极分子心得体会
2016/01/06 职场文书