tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中操作字符串之rstrip()方法的使用
May 19 Python
基于python历史天气采集的分析
Feb 14 Python
Python Datetime模块和Calendar模块用法实例分析
Apr 15 Python
Pycharm简单使用教程(入门小结)
Jul 04 Python
详解如何从TensorFlow的mnist数据集导出手写体数字图片
Aug 05 Python
opencv3/python 鼠标响应操作详解
Dec 11 Python
python numpy数组中的复制知识解析
Feb 03 Python
python GUI库图形界面开发之PyQt5不规则窗口实现与显示GIF动画的详细方法与实例
Mar 09 Python
Django表单提交后实现获取相同name的不同value值
May 14 Python
Python常用断言函数实例汇总
Nov 30 Python
Python编写万花尺图案实例
Jan 03 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
PHP 图片水印类代码
2012/08/27 PHP
php构造函数实例讲解
2013/11/13 PHP
PHP在网页中动态生成PDF文件详细教程
2014/07/05 PHP
Smarty变量调节器失效的解决办法
2014/08/20 PHP
php 无限分类 树形数据格式化代码
2016/10/11 PHP
Laravel中如何轻松容易的输出完整的SQL语句
2020/07/26 PHP
PHP实现简单注册登录系统
2020/12/28 PHP
ajax无刷新动态调用股票信息(改良版)
2008/11/01 Javascript
JavaScript Distilled 基础知识与函数
2010/04/07 Javascript
Js 时间函数getYear()的使用问题探讨
2013/04/01 Javascript
node.js调用C++开发的模块实例
2015/07/03 Javascript
基于javascript实现图片滑动效果
2016/05/07 Javascript
js获取鼠标点击的对象,点击另一个按钮删除该对象的实现代码
2016/05/13 Javascript
详解js中的apply与call的用法
2016/07/30 Javascript
js实现上传图片预览方法
2016/10/25 Javascript
详解JS: reduce方法实现 webpack多文件入口
2017/02/14 Javascript
js自定义input文件上传样式
2018/10/26 Javascript
ES10的13个新特性示例(小结)
2019/09/23 Javascript
python使用递归解决全排列数字示例
2014/02/11 Python
Python中Random和Math模块学习笔记
2015/05/18 Python
python构建深度神经网络(续)
2018/03/10 Python
selenium+python实现自动化登录的方法
2018/09/04 Python
Linux下python3.6.1环境配置教程
2018/09/26 Python
详解用python计算阶乘的几种方法
2019/08/14 Python
高考考python编程是真的吗
2020/07/20 Python
python 装饰器重要在哪
2021/02/14 Python
html5的新玩法——语音搜索
2013/01/03 HTML / CSS
华为俄罗斯官方网上商城:购买Huawei手机和平板
2017/04/21 全球购物
应届本科生推荐信范文
2013/12/25 职场文书
公司道歉信范文
2014/01/09 职场文书
六一儿童节活动策划方案
2014/01/27 职场文书
公务员中国梦演讲稿
2014/08/19 职场文书
商务英语专业大学生职业生涯规划书
2014/09/14 职场文书
努力工作保证书
2015/02/28 职场文书
巾帼建功标兵先进事迹材料
2016/02/29 职场文书
SpringBoot 集成短信和邮件 以阿里云短信服务为例
2022/04/22 Java/Android