tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现socket客户端和服务端简单示例
Feb 24 Python
在Python中操作字典之setdefault()方法的使用
May 21 Python
简单了解python模块概念
Jan 11 Python
Python动态导入模块的方法实例分析
Jun 28 Python
浅谈PyQt5 的帮助文档查找方法,可以查看每个类的方法
Jun 25 Python
Tensorflow模型实现预测或识别单张图片
Jul 19 Python
python list转置和前后反转的例子
Aug 26 Python
python自动结束mysql慢查询会话的实例代码
Oct 27 Python
python GUI库图形界面开发之pyinstaller打包python程序为exe安装文件
Feb 26 Python
Python random库使用方法及异常处理方案
Mar 02 Python
python 使用csv模块读写csv格式文件的示例
Dec 02 Python
解决Python字典查找报Keyerror的问题
May 26 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
php.ini 配置文件的深入解析
2013/06/17 PHP
PHP处理CSV表格文件的常用操作方法总结
2016/07/01 PHP
jquery控制listbox中项的移动并排序
2009/11/12 Javascript
Date对象格式化函数代码
2010/07/17 Javascript
Extjs4实现两个GridPanel之间数据拖拽功能具体方法
2013/11/21 Javascript
angularjs的一些优化小技巧
2014/12/06 Javascript
DOM基础教程之使用DOM控制表单
2015/01/20 Javascript
jquery对dom节点的操作【推荐】
2016/04/15 Javascript
jQuery原理系列-css选择器的简单实现
2016/06/07 Javascript
Jquery揭秘系列:ajax原生js实现详解(推荐)
2016/06/08 Javascript
js数组的五种迭代方法及两种归并方法(推荐)
2016/06/14 Javascript
详解微信小程序Radio选中样式切换
2017/07/06 Javascript
ECMAScript6变量的解构赋值实例详解
2017/09/19 Javascript
Vue.directive 自定义指令的问题小结
2018/03/04 Javascript
Vue.js中的computed工作原理
2018/03/22 Javascript
Node.js模块全局安装路径配置方法
2018/05/17 Javascript
webpack4 CSS Tree Shaking的使用
2018/09/03 Javascript
JavaScript基础教程之如何实现一个简单的promise
2018/09/11 Javascript
python 多线程应用介绍
2012/12/19 Python
python结合API实现即时天气信息
2016/01/19 Python
Python中文分词实现方法(安装pymmseg)
2016/06/14 Python
Python+tkinter使用40行代码实现计算器功能
2018/01/30 Python
python对验证码降噪的实现示例代码
2019/11/12 Python
如何基于python实现画不同品种的樱花树
2020/01/03 Python
Python同时迭代多个序列的方法
2020/07/28 Python
Watch Station官方网站:世界一流的手表和智能手表
2020/01/05 全球购物
澳大利亚在线划船、露营和钓鱼商店:BCF Australia
2020/03/22 全球购物
教师的实习鉴定
2013/12/15 职场文书
班级学习计划书
2014/04/27 职场文书
大型会议策划方案
2014/05/17 职场文书
小学一年级数学教学计划
2015/01/20 职场文书
校本研修个人总结
2015/02/28 职场文书
家庭暴力离婚起诉书
2015/05/18 职场文书
主题班会开场白
2015/06/01 职场文书
暂住证证明
2015/06/19 职场文书
红楼梦读书笔记
2015/06/25 职场文书