tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用java的Webservice示例
Mar 10 Python
Python是编译运行的验证方法
Jan 30 Python
python实现获取客户机上指定文件并传输到服务器的方法
Mar 16 Python
Python删除Java源文件中全部注释的实现方法
Aug 30 Python
基于并发服务器几种实现方法(总结)
Dec 29 Python
TensorFlow数据输入的方法示例
Jun 19 Python
python中reader的next用法
Jul 24 Python
Python Opencv实现图像轮廓识别功能
Mar 23 Python
python对绑定事件的鼠标、按键的判断实例
Jul 17 Python
Python基于staticmethod装饰器标示静态方法
Oct 17 Python
python 使用csv模块读写csv格式文件的示例
Dec 02 Python
Python turtle编写简单的球类小游戏
Mar 31 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
php中判断数组是一维,二维,还是多维的解决方法
2013/05/04 PHP
ThinkPHP验证码和分页实例教程
2014/08/22 PHP
PHP Cookie学习笔记
2016/08/23 PHP
php curl中gzip的压缩性能测试实例分析
2016/11/08 PHP
javascript innerHTML、outerHTML、innerText、outerText的区别
2008/11/24 Javascript
利用jQuery的$.event.fix函数统一浏览器event事件处理
2009/12/21 Javascript
Javascript WebSocket使用实例介绍(简明入门教程)
2014/04/16 Javascript
Firefox下无法正常显示年份的解决方法
2014/09/04 Javascript
node.js中的path.sep方法使用说明
2014/12/08 Javascript
JavaScript前端图片加载管理器imagepool使用详解
2014/12/29 Javascript
jQuery编程中的一些核心方法简介
2015/08/14 Javascript
11种ASP连接数据库的方法
2015/09/18 Javascript
AngularJS directive返回对象属性详解
2016/03/28 Javascript
原生javascript实现分享到朋友圈功能 支持ios和android
2016/05/11 Javascript
AngularJs 利用百度地图API 定位当前位置 获取地址信息
2017/01/18 Javascript
jQuery窗口拖动功能的实现代码
2017/02/04 Javascript
微信小程序 wx:for的使用实例详解
2017/04/27 Javascript
vue+vux实现移动端文件上传样式
2017/07/28 Javascript
vue-cli监听组件加载完成的方法
2018/09/07 Javascript
js中实例与对象的区别讲解
2019/01/21 Javascript
vue resource发送请求的几种方式
2019/09/30 Javascript
JavaScript进阶(三)闭包原理与用法详解
2020/05/09 Javascript
Vue通过provide inject实现组件通信
2020/09/03 Javascript
Python splitlines使用技巧
2008/09/06 Python
Python实现从URL地址提取文件名的方法
2015/05/15 Python
Python tkinter事件高级用法实例
2018/01/31 Python
python自制包并用pip免提交到pypi仅安装到本机【推荐】
2019/06/03 Python
Python完成哈夫曼树编码过程及原理详解
2019/07/29 Python
python sorted方法和列表使用解析
2019/11/18 Python
python爬虫筛选工作实例讲解
2020/11/23 Python
一百多行代码实现react拖拽hooks
2021/03/23 Javascript
工程班组长岗位职责
2013/12/30 职场文书
酒店优秀员工推荐信
2015/03/24 职场文书
《卖火柴的小女孩》教学反思
2016/02/19 职场文书
68句权威创业名言
2019/08/26 职场文书
mysql使用instr达到in(字符串)的效果
2022/04/03 MySQL