TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用正则搜索字符串或文件中的浮点数代码实例
Jul 11 Python
python各种语言间时间的转化实现代码
Mar 23 Python
深度定制Python的Flask框架开发环境的一些技巧总结
Jul 12 Python
python获取文件路径、文件名、后缀名的实例
Apr 23 Python
python opencv判断图像是否为空的实例
Jan 26 Python
Python提取特定时间段内数据的方法实例
Apr 01 Python
Python实现获取系统临时目录及临时文件的方法示例
Jun 26 Python
使用Django搭建一个基金模拟交易系统教程
Nov 18 Python
Python 过滤错误log并导出的实例
Dec 26 Python
python标识符命名规范原理解析
Jan 10 Python
Python ArgumentParse的subparser用法说明
Apr 20 Python
使用sublime text3搭建Python编辑环境的实现
Jan 12 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
PHP 显示客户端IP与服务器IP的代码
2010/10/12 PHP
PHP优化之批量操作MySQL实例分析
2020/04/23 PHP
javascript 嵌套的函数(作用域链)
2010/03/15 Javascript
jQuery Clone Bug解决代码
2010/12/22 Javascript
js 使用form表单select类实现级联菜单效果
2012/12/19 Javascript
JavaScript 中的日期和时间及表示标准介绍
2013/08/21 Javascript
JQuery结合CSS操作打印样式的方法
2013/12/24 Javascript
javascript实现json页面分页实例代码
2014/02/20 Javascript
JQuery选择器绑定事件及修改内容的方法
2015/01/23 Javascript
jQuery实现带滚动导航效果的全屏滚动相册实例
2015/06/19 Javascript
JS跨域解决方案之使用CORS实现跨域
2016/04/14 Javascript
javascript学习之json入门
2016/12/22 Javascript
JavaScript优化以及前段开发小技巧
2017/02/02 Javascript
bootstrap table表格使用方法详解
2017/04/26 Javascript
使用canvas实现一个vue弹幕组件功能
2018/11/30 Javascript
vue 表单验证按钮事件交由父组件触发的方法
2018/12/17 Javascript
vue 兄弟组件的信息传递的方法实例详解
2019/08/30 Javascript
JQuery表单元素取值赋值方法总结
2020/05/12 jQuery
Python实现的异步代理爬虫及代理池
2017/03/17 Python
用pycharm开发django项目示例代码
2019/06/13 Python
Python 运行.py文件和交互式运行代码的区别详解
2019/07/02 Python
django 捕获异常和日志系统过程详解
2019/07/18 Python
使用OpenCV circle函数图像上画圆的示例代码
2019/12/27 Python
TensorFlow 显存使用机制详解
2020/02/03 Python
Matplotlib 折线图plot()所有用法详解
2020/07/28 Python
Python如何把字典写入到CSV文件的方法示例
2020/08/23 Python
Python实现小黑屋游戏的完整实例
2021/01/06 Python
美国大城市最热门旅游景点门票:CityPASS
2016/12/16 全球购物
NBA欧洲商店(西班牙):NBA Europe Store ES
2019/04/16 全球购物
eBay奥地利站:eBay.at
2019/07/24 全球购物
中专毕业生自我鉴定
2014/02/02 职场文书
个性发展自我评价
2014/02/11 职场文书
英语教学随笔感言
2014/02/20 职场文书
2015年计划生育协会工作总结
2015/05/13 职场文书
南阳市白酒市场的调查报告
2019/11/08 职场文书
python关于集合的知识案例详解
2021/05/30 Python