TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python运用于数据分析的简单教程
Mar 27 Python
Python SQLite3简介
Feb 22 Python
django传值给模板, 再用JS接收并进行操作的实例
May 28 Python
Pycharm取消py脚本中SQL识别的方法
Nov 29 Python
python+opencv实现霍夫变换检测直线
Oct 23 Python
python实现得到当前登录用户信息的方法
Jun 21 Python
用Python将Excel数据导入到SQL Server的例子
Aug 24 Python
Pyecharts绘制全球流向图的示例代码
Jan 08 Python
python opencv实现简易画图板
Aug 27 Python
Django-simple-captcha验证码包使用方法详解
Nov 28 Python
python 基于selectors库实现文件上传与下载
Dec 31 Python
Python机器学习之PCA降维算法详解
May 19 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
基于mysql的论坛(1)
2006/10/09 PHP
php在程序中将网页生成word文档并提供下载的代码
2012/10/09 PHP
PHP自带函数给数字或字符串自动补齐位数
2014/07/29 PHP
php创建、获取cookie及基础要点分析
2015/01/26 PHP
PHP发送AT指令实例代码
2016/05/26 PHP
Swoole实现异步投递task任务案例详解
2019/04/02 PHP
[原创]IE view-source 无法查看看源码 JavaScript看网页源码
2009/07/19 Javascript
javascript预览上传图片发现的问题的解决方法
2010/11/25 Javascript
javascript开发随笔一 preventDefault的必要
2011/11/25 Javascript
JQuery获取各种宽度、高度(format函数)实例
2013/03/04 Javascript
node.js中的fs.ftruncate方法使用说明
2014/12/15 Javascript
jquery简单实现带渐显效果的选项卡菜单代码
2015/09/01 Javascript
HTML中setCapture、releaseCapture 使用方法浅析
2016/09/25 Javascript
整理一些最近经常遇到的前端面试题
2017/04/25 Javascript
JavaScript中document.referrer的用法详解
2017/07/04 Javascript
微信小程序使用map组件实现路线规划功能示例
2019/01/22 Javascript
vue实现表单未编辑或未保存离开弹窗提示功能
2020/04/08 Javascript
javascript实现京东登录显示隐藏密码
2020/08/02 Javascript
Python简单计算数组元素平均值的方法示例
2017/12/26 Python
linux下python使用sendmail发送邮件
2018/05/22 Python
python2.x实现人民币转大写人民币
2018/06/20 Python
pandas 时间格式转换的实现
2019/07/06 Python
基于python操作ES实例详解
2019/11/16 Python
Python错误的处理方法
2020/06/23 Python
Python 合并拼接字符串的方法
2020/07/28 Python
python高级特性简介
2020/08/13 Python
Pycharm连接gitlab实现过程图解
2020/09/01 Python
一款利用html5和css3实现的3D滚动特效的教程
2015/01/04 HTML / CSS
canvas 如何绘制线段的实现方法
2018/07/12 HTML / CSS
美术专业学生个人自我评价
2013/09/19 职场文书
自我鉴定的范文
2013/10/03 职场文书
商务主管岗位职责
2013/12/08 职场文书
初三化学教学反思
2014/01/23 职场文书
大学生职业生涯规划书汇总
2014/03/20 职场文书
经理任命书模板
2014/06/06 职场文书
公安机关纪律作风整顿个人剖析材料材料
2014/10/10 职场文书