TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python删除文件示例分享
Jan 28 Python
Python常见数据结构详解
Jul 24 Python
浅析Python中的join()方法的使用
May 19 Python
Python实现比较两个文件夹中代码变化的方法
Jul 10 Python
python实现决策树C4.5算法详解(在ID3基础上改进)
May 31 Python
Python编程scoketServer实现多线程同步实例代码
Jan 29 Python
python3.6 如何将list存入txt后再读出list的方法
Jul 02 Python
Python shutil模块用法实例分析
Oct 02 Python
pycharm中导入模块错误时提示Try to run this command from the system terminal
Mar 26 Python
Pycharm安装python库的方法
Nov 24 Python
解决python 输出到csv 出现多空行的情况
Mar 24 Python
PyQt5实现多张图片显示并滚动
Jun 11 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
JoshChen_web格式编码UTF8-无BOM的小细节分析
2013/08/16 PHP
php实现微信公众平台账号自定义菜单类
2015/10/11 PHP
PHP实现文件上传和多文件上传
2015/12/24 PHP
PHP读取大文件末尾N行的高效方法推荐
2016/06/03 PHP
PHP中字符与字节的区别及字符串与字节转换示例
2016/10/15 PHP
js绑定事件this指向发生改变的问题解决方法
2013/04/23 Javascript
jquery实现弹出层完美居中效果
2014/03/03 Javascript
用js格式化金额可设置保留的小数位数
2014/05/09 Javascript
laypage分页控件使用实例详解
2016/05/19 Javascript
全面了解addEventListener和on的区别
2016/07/14 Javascript
javascript 将共享属性迁移到原型中去的实现方法
2016/08/31 Javascript
js实现省市级联效果分享
2017/08/10 Javascript
jQuery实现的文字逐行向上间歇滚动效果示例
2017/09/06 jQuery
React styled-components设置组件属性的方法
2018/08/07 Javascript
Nodejs使用archiver-zip-encrypted库加密压缩文件时报错(解决方案)
2019/11/18 NodeJs
webpack打包优化的几个方法总结
2020/02/10 Javascript
JS数组的高级使用方法示例小结
2020/03/14 Javascript
python脚本替换指定行实现步骤
2017/07/11 Python
python梯度下降法的简单示例
2018/08/31 Python
解决Pycharm界面的子窗口不见了的问题
2019/01/17 Python
python实现比较类的两个instance(对象)是否相等的方法分析
2019/06/26 Python
pytorch模型预测结果与ndarray互转方式
2020/01/15 Python
pycharm通过anaconda安装pyqt5的教程
2020/03/24 Python
python pip如何手动安装二进制包
2020/09/30 Python
CSS3+font字体文件实现圆形半透明菜单具体步骤(图解)
2013/06/03 HTML / CSS
简短大学毕业感言
2014/01/18 职场文书
2015年入党决心书
2015/02/05 职场文书
军训阅兵新闻稿
2015/07/17 职场文书
读完《骆驼祥子》的观后感!
2019/07/05 职场文书
导游词之苏州盘门景区
2019/11/12 职场文书
导游词之江苏同里古镇
2019/11/18 职场文书
Python+Appium新手教程
2021/04/17 Python
将Python代码打包成.exe可执行文件的完整步骤
2021/05/12 Python
MySQL 重命名表的操作方法及注意事项
2021/05/21 MySQL
Redis集群新增、删除节点以及动态增加内存的方法
2021/09/04 Redis
使用compose函数优化代码提高可读性及扩展性
2022/06/16 Javascript