TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中动态检测编码chardet的使用教程
Jul 06 Python
python读取与写入csv格式文件的示例代码
Dec 16 Python
Python 记录日志的灵活性和可配置性介绍
Feb 27 Python
对python中array.sum(axis=?)的用法介绍
Jun 28 Python
python+selenium实现自动化百度搜索关键词
Jun 03 Python
Django Rest framework认证组件详细用法
Jul 25 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 Python
详解字符串在Python内部是如何省内存的
Feb 03 Python
python实现低通滤波器代码
Feb 26 Python
Python实现数字的格式化输出
Aug 01 Python
详解如何修改python中字典的键和值
Sep 29 Python
详解Python中*args和**kwargs的使用
Apr 07 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
一个图片地址分解程序(用于PHP小偷程序)
2014/08/23 PHP
PHP基于工厂模式实现的计算器实例
2015/07/16 PHP
PHP操作mysql数据库分表的方法
2016/06/09 PHP
PHP消息队列实现及应用详解【队列处理订单系统和配送系统】
2019/05/20 PHP
Jquery实战_读书笔记1—选择jQuery
2010/01/22 Javascript
jquery的index方法实现tab效果
2011/02/16 Javascript
jQuery插件 selectToSelect使用方法
2013/10/02 Javascript
js时间日期格式化封装函数
2014/12/02 Javascript
javascript实现日期按月份加减
2015/05/15 Javascript
使用JavaScript脚本无法直接改变Asp.net中Checkbox控件的Enable属性的解决方法
2015/09/16 Javascript
微信jssdk在iframe页面失效问题的解决措施
2016/03/03 Javascript
Jquery实现的简单轮播效果【附实例】
2016/04/19 Javascript
使用node.js中的Buffer类处理二进制数据的方法
2016/11/26 Javascript
Angular1.x复杂指令实例详解
2017/03/01 Javascript
详解React Native网络请求fetch简单封装
2017/08/10 Javascript
针对Vue路由history模式下Nginx后台配置操作
2020/10/22 Javascript
Vue 打包的静态文件不能直接运行的原因及解决办法
2020/11/19 Vue.js
[57:55]完美世界DOTA2联赛PWL S3 Magma vs Phoenix 第二场 12.12
2020/12/16 DOTA
Python加pyGame实现的简单拼图游戏实例
2015/05/15 Python
Jupyter安装nbextensions,启动提示没有nbextensions库
2020/04/23 Python
pycharm在调试python时执行其他语句的方法
2018/11/29 Python
Python实现的在特定目录下导入模块功能分析
2019/02/11 Python
PyTorch之图像和Tensor填充的实例
2019/08/18 Python
python 实现rolling和apply函数的向下取值操作
2020/06/08 Python
国际书籍零售商:Wordery
2017/11/01 全球购物
Java多态性的定义以及类型
2014/09/16 面试题
别名指示符是什么
2012/10/08 面试题
业务部主管岗位职责
2014/01/29 职场文书
大学同学十年聚会感言
2014/02/21 职场文书
团拜会策划方案
2014/06/07 职场文书
公司证明怎么写
2014/09/22 职场文书
毕业实习指导教师评语
2014/12/31 职场文书
夫妻吵架保证书
2015/05/08 职场文书
详解MySQL多版本并发控制机制(MVCC)源码
2021/06/23 MySQL
Python实现文字pdf转换图片pdf效果
2022/04/03 Python
项目中Nginx多级代理是如何获取客户端的真实IP地址
2022/05/30 Servers