TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python数组遍历的简单实现方法小结
Apr 27 Python
Python数据结构与算法之字典树实现方法示例
Dec 13 Python
快速了解Python开发中的cookie及简单代码示例
Jan 17 Python
django 解决manage.py migrate无效的问题
May 27 Python
Python爬虫基础之XPath语法与lxml库的用法详解
Sep 13 Python
Python函数返回不定数量的值方法
Jan 22 Python
Python Django框架模板渲染功能示例
Nov 08 Python
如何在python中写hive脚本
Nov 08 Python
python之array赋值技巧分享
Nov 28 Python
python如何使用jt400.jar包代码实例
Dec 20 Python
使用pytorch实现可视化中间层的结果
Dec 30 Python
Python+Selenium实现抖音、快手、B站、小红书、微视、百度好看视频、西瓜视频、微信视频号、搜狐视频、一点号、大风号、趣头条等短视频自动发布
Apr 13 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
PHP 截取字符串 分别适合GB2312和UTF8编码情况
2009/02/12 PHP
浅谈PHP正则表达式中修饰符/i, /is, /s, /isU
2014/10/21 PHP
浅谈PHP中foreach/in_array的使用
2015/11/02 PHP
PHP实现HTML页面静态化的方法
2015/11/04 PHP
编写PHP脚本来实现WordPress中评论分页的功能
2015/12/10 PHP
php接口技术实例详解
2016/12/07 PHP
php在linux环境中如何使用redis详解
2020/12/15 PHP
Jquery cookie操作代码
2010/03/14 Javascript
myeclipse安装jQuery插件的方法
2011/03/29 Javascript
JS前端框架关于重构的失败经验分享
2013/03/17 Javascript
javascript重写alert方法的实例代码
2013/03/29 Javascript
Javascript实现返回上一页面并刷新的小例子
2013/12/11 Javascript
JQuery插件iScroll实现下拉刷新,滚动翻页特效
2014/06/22 Javascript
JavaScript中的setMilliseconds()方法使用详解
2015/06/11 Javascript
浅析Node.js中使用依赖注入的相关问题及解决方法
2015/06/24 Javascript
JS Ajax请求如何防止重复提交
2016/06/13 Javascript
js实现3D图片展示效果
2017/03/09 Javascript
Bootstrap 表单验证formValidation 实现表单动态验证功能
2017/05/17 Javascript
mui back 返回刷新页面的实例
2017/12/06 Javascript
微信小程序使用slider设置数据值及switch开关组件功能【附源码下载】
2017/12/09 Javascript
详解weex默认webpack.config.js改造
2018/01/08 Javascript
[00:58]2016年国际邀请赛勇士令状宣传片
2016/06/01 DOTA
python itchat实现微信自动回复的示例代码
2017/08/14 Python
Python装饰器用法实例总结
2018/02/07 Python
pandas求两个表格不相交的集合方法
2018/12/08 Python
基于Python+Appium实现京东双十一自动领金币功能
2019/10/31 Python
Python3操作YAML文件格式方法解析
2020/04/10 Python
python 指定源路径来解决import问题的操作
2021/03/04 Python
HTML5的结构和语义(2):结构
2008/10/17 HTML / CSS
HTML5手机端弹出遮罩菜单特效代码
2016/01/27 HTML / CSS
龟牌英国商店:Turtle Wax Brand Store UK
2019/07/02 全球购物
应届生人事助理求职信
2013/11/09 职场文书
关于五一放假的通知
2015/08/18 职场文书
大学迎新生欢迎词
2015/09/29 职场文书
Java如何实现树的同构?
2021/06/22 Java/Android
Vue ECharts实现机舱座位选择展示功能
2022/05/15 Vue.js