tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中输出ASCII大文字、艺术字、字符字小技巧
Apr 28 Python
关于反爬虫的一些简单总结
Dec 13 Python
Python命令行解析模块详解
Feb 01 Python
Python3爬虫之urllib携带cookie爬取网页的方法
Dec 28 Python
Python爬虫实现使用beautifulSoup4爬取名言网功能案例
Sep 15 Python
python将print输出的信息保留到日志文件中
Sep 27 Python
python pycharm的安装及其使用
Oct 11 Python
Tensorflow 实现分批量读取数据
Jan 04 Python
python新手学习可变和不可变对象
Jun 11 Python
基于python实现ROC曲线绘制广场解析
Jun 28 Python
matplotlib对象拾取事件处理的实现
Jan 14 Python
微信小程序调用python模型
Apr 21 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
php visitFile()遍历指定文件夹函数
2010/08/21 PHP
php中CI操作多个数据库的代码
2012/07/05 PHP
php中使用__autoload()自动加载未定义类的实现代码
2013/02/06 PHP
php生成QRcode实例
2014/09/22 PHP
PHP发送短信代码分享
2015/08/11 PHP
javascript 解析url的search方法
2010/02/09 Javascript
在浏览器中获取当前执行的脚本文件名的代码
2011/07/19 Javascript
jQuery 瀑布流 浮动布局(一)(延迟AJAX加载图片)
2012/05/23 Javascript
JS应用正则表达式转换大小写示例
2014/09/18 Javascript
JavaScript将Web页面内容导出到Word及Excel的方法
2015/02/13 Javascript
JS获取图片高度宽度的方法分享
2015/04/17 Javascript
深入分析JSON编码格式提交表单数据
2015/06/25 Javascript
Jquery检验手机号是否符合规则并根据手机号检测结果将提交按钮设为不同状态
2015/11/26 Javascript
Express实现前端后端通信上传图片之存储数据库(mysql)傻瓜式教程(一)
2015/12/10 Javascript
jQuery使用Layer弹出层插件闪退问题
2016/12/22 Javascript
AngularJS constant和value区别详解
2017/02/28 Javascript
浅析JS中的 map, filter, some, every, forEach, for in, for of 用法总结
2017/03/29 Javascript
微信小程序 实例开发总结
2017/04/26 Javascript
详解Node.js开发中的express-session
2017/05/19 Javascript
cocos creator Touch事件应用(触控选择多个子节点的实例)
2017/09/10 Javascript
jQuery实现轮播图及其原理详解
2020/04/12 jQuery
借助云开发实现小程序短信验证码的发送
2020/01/06 Javascript
[49:02]KG vs Infamous 2019国际邀请赛淘汰赛 败者组BO1 8.20.mp4
2020/07/19 DOTA
numpy实现合并多维矩阵、list的扩展方法
2018/05/08 Python
python实现泊松图像融合
2018/07/26 Python
Python多进程原理与用法分析
2018/08/21 Python
解决python3 安装完Pycurl在import pycurl时报错的问题
2018/10/15 Python
python3发送邮件需要经过代理服务器的示例代码
2019/07/25 Python
Pytorch之Variable的用法
2019/12/31 Python
德国箱包网上商店:koffer24.de
2016/07/27 全球购物
我的五年职业生涯规划
2014/01/23 职场文书
2014年教研员工作总结
2014/12/23 职场文书
保送生自荐信
2015/03/06 职场文书
2015年党日活动总结范文
2015/03/25 职场文书
Python3 使用pip安装git并获取Yahoo金融数据的操作
2021/04/08 Python
中国十大神话动漫电影排行榜 哪吒登顶 白蛇缘起排第七
2022/03/21 国漫