tensorflow模型保存、加载之变量重命名实例


Posted in Python onJanuary 21, 2020

话不多说,干就完了。

变量重命名的用处?

简单定义:简单来说就是将模型A中的参数parameter_A赋给模型B中的parameter_B

使用场景:当需要使用已经训练好的模型参数,尤其是使用别人训练好的模型参数时,往往别人模型中的参数命名方式与自己当前的命名方式不同,所以在加载模型参数时需要对参数进行重命名,使得代码更简洁易懂。

实现方法:

1)、模型保存

import os
import tensorflow as tf
 
weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                            mean=0.0,
                            stddev=0.1),
           dtype=tf.float32,
           name="weights")
biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
           dtype=tf.float32,
           name="biases")
 
weights_2 = tf.Variable(initial_value=weights.initialized_value(),
            dtype=tf.float32,
            name="weights_2")
 
# saver checkpoint
if os.path.exists("checkpoints") is False:
  os.makedirs("checkpoints")
 
saver = tf.train.Saver()
with tf.Session() as sess:
  init_op = [tf.global_variables_initializer()]
  sess.run(init_op)
  saver.save(sess=sess, save_path="checkpoints/variable.ckpt")

2)、模型加载(变量名称保持不变)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
def restore_variable(sess):
  # need not initilize variable, but need to define the same variable like checkpoint
  weights = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="weights")
  biases = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="biases")
 
  weights_2 = tf.Variable(initial_value=weights.initialized_value(),
              dtype=tf.float32,
              name="weights_2")
 
  saver = tf.train.Saver()
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  weights_val, weights_2_val = sess.run(
    [
      tf.reshape(weights, shape=[2048]),
      tf.reshape(weights_2, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(weights_val))], weights_val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(weights_2_val))], weights_2_val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable(sess)

3)、模型加载(变量重命名)

import tensorflow as tf
from matplotlib import pyplot as plt
import os
 
current_path = os.path.dirname(os.path.abspath(__file__))
 
 
def restore_variable_renamed(sess):
  conv1_w = tf.Variable(initial_value=tf.truncated_normal(shape=[1024, 2],
                              mean=0.0,
                              stddev=0.1),
             dtype=tf.float32,
             name="conv1_w")
  conv1_b = tf.Variable(initial_value=tf.zeros(shape=[2]),
             dtype=tf.float32,
             name="conv1_b")
 
  conv2_w = tf.Variable(initial_value=conv1_w.initialized_value(),
             dtype=tf.float32,
             name="conv2_w")
 
  # variable named 'weights' in ckpt assigned to current variable conv1_w
  # variable named 'biases' in ckpt assigned to current variable conv1_b
  # variable named 'weights_2' in ckpt assigned to current variable conv2_w
  saver = tf.train.Saver({
    "weights": conv1_w,
    "biases": conv1_b,
    "weights_2": conv2_w
  })
 
  ckpt_path = os.path.join(current_path, "checkpoints", "variable.ckpt")
  saver.restore(sess=sess, save_path=ckpt_path)
 
  conv1_w__val, conv2_w__val = sess.run(
    [
      tf.reshape(conv1_w, shape=[2048]),
      tf.reshape(conv2_w, shape=[2048])
    ]
  )
 
  plt.subplot(1, 2, 1)
  plt.scatter([i for i in range(len(conv1_w__val))], conv1_w__val)
  plt.subplot(1, 2, 2)
  plt.scatter([i for i in range(len(conv2_w__val))], conv2_w__val)
  plt.show()
 
 
if __name__ == '__main__':
  with tf.Session() as sess:
    restore_variable_renamed(sess)

总结:

# 之前模型中叫 'weights'的变量赋值给当前的conv1_w变量

# 之前模型中叫 'biases' 的变量赋值给当前的conv1_b变量

# 之前模型中叫 'weights_2'的变量赋值给当前的conv2_w变量

saver = tf.train.Saver({

"weights": conv1_w,

"biases": conv1_b,

"weights_2": conv2_w

})

以上这篇tensorflow模型保存、加载之变量重命名实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用smtplib和email封装python发送邮件模块类分享
Feb 17 Python
python中的函数用法入门教程
Sep 02 Python
linux平台使用Python制作BT种子并获取BT种子信息的方法
Jan 20 Python
教你使用python实现微信每天给女朋友说晚安
Mar 23 Python
python3第三方爬虫库BeautifulSoup4安装教程
Jun 19 Python
python中正则表达式 re.findall 用法
Oct 23 Python
为什么说Python可以实现所有的算法
Oct 04 Python
python FTP批量下载/删除/上传实例
Dec 22 Python
Python 使用生成器代替线程的方法
Aug 04 Python
详解基于python的图像Gabor变换及特征提取
Oct 26 Python
pycharm 快速解决python代码冲突的问题
Jan 15 Python
Python装饰器的练习题
Nov 23 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
You might like
咖啡常见的种类
2021/03/03 新手入门
PHP循环语句笔记(foreach,list)
2011/11/29 PHP
Thinkphp和onethink实现微信支付插件
2016/04/13 PHP
Yii2前后台分离及migrate使用(七)
2016/05/04 PHP
24条货真价实的PHP代码优化技巧
2016/07/28 PHP
Web开发者必备的12款超赞jQuery插件
2010/12/03 Javascript
使用JS读秒使用示例
2013/09/21 Javascript
JavaScript实现图片自动加载的瀑布流效果
2016/04/11 Javascript
JavaScript 弹出子窗体并返回结果到父窗体的实现代码
2016/05/28 Javascript
plupload+artdialog实现多平台上传文件
2016/07/19 Javascript
微信小程序 progress组件详解及实例代码
2016/10/25 Javascript
jQuery实现div跟随鼠标移动
2020/08/20 jQuery
Vue项目中quill-editor带样式编辑器的使用方法
2017/08/08 Javascript
angularjs 动态从后台获取下拉框的值方法
2018/08/13 Javascript
jQuery序列化form表单数据为JSON对象的实现方法
2018/09/20 jQuery
nuxt中使用路由守卫的方法步骤
2019/01/27 Javascript
Vue项目history模式下微信分享爬坑总结
2019/03/29 Javascript
常见的浏览器存储方式(cookie、localStorage、sessionStorage)
2019/05/07 Javascript
Vue双向数据绑定(MVVM)的原理
2020/10/03 Javascript
通过JS判断网页是否为手机打开
2020/10/28 Javascript
html+vue.js 实现漂亮分页功能可兼容IE
2020/11/07 Javascript
[00:56]跨越时空加入战场 全新祈求者身心“失落奇艺侍祭”展示
2019/07/20 DOTA
python通过自定义isnumber函数判断字符串是否为数字的方法
2015/04/23 Python
python 中的int()函数怎么用
2017/10/17 Python
pandas DataFrame实现几列数据合并成为新的一列方法
2018/06/08 Python
使用Scrapy爬取动态数据
2018/10/21 Python
Python os.access()用法实例
2019/02/18 Python
eclipse创建python项目步骤详解
2019/05/10 Python
pyqt5实现绘制ui,列表窗口,滚动窗口显示图片的方法
2019/06/20 Python
Python的轻量级ORM框架peewee使用教程
2021/02/05 Python
CPB肌肤之钥美国官网:Clé de Peau Beauté
2017/09/05 全球购物
金融学专科生自我鉴定
2014/02/21 职场文书
廉洁使者实施方案
2014/03/29 职场文书
租房协议书范例
2014/10/14 职场文书
教你如何使用Python开发一个钉钉群应答机器人
2021/06/21 Python
用PYTHON去计算88键钢琴的琴键频率和音高
2022/04/10 Python