TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中关于时间和日期函数的常用计算总结(time和datatime)
Mar 08 Python
详解Python Socket网络编程
Jan 05 Python
详解Python中的__getitem__方法与slice对象的切片操作
Jun 27 Python
Python语言实现百度语音识别API的使用实例
Dec 13 Python
python tkinter界面居中显示的方法
Oct 11 Python
python找出完数的方法
Nov 12 Python
python运行时强制刷新缓冲区的方法
Jan 14 Python
python实现两个经纬度点之间的距离和方位角的方法
Jul 05 Python
Django框架HttpRequest对象用法实例分析
Nov 01 Python
基于pytorch的lstm参数使用详解
Jan 14 Python
python IDLE添加行号显示教程
Apr 25 Python
cookies应对python反爬虫知识点详解
Nov 25 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
全国FM电台频率大全 - 18 湖南省
2020/03/11 无线电
针对初学PHP者的疑难问答(2)
2006/10/09 PHP
PHP 面向对象 PHP5 中的常量
2010/05/05 PHP
9个PHP开发常用功能函数小结
2011/07/15 PHP
php cookie使用方法学习笔记分享
2013/11/07 PHP
Thinkphp调用Image类生成缩略图的方法
2015/03/07 PHP
php+redis实现多台服务器内网存储session并读取示例
2017/01/12 PHP
利用PHP_XLSXWriter代替PHPExcel的方法示例
2017/07/16 PHP
ScrollDown的基本操作示例
2013/06/09 Javascript
JS验证日期的格式YYYY-mm-dd 具体实现
2013/06/29 Javascript
jQuery插件kinMaxShow扩展效果用法实例
2015/05/04 Javascript
jQuery实现模拟marquee标签效果
2015/07/14 Javascript
Bootstrap基本样式学习笔记之标签(5)
2016/12/07 Javascript
JS实现Ajax的方法分析
2016/12/20 Javascript
用jQuery旋转插件jqueryrotate制作转盘抽奖
2017/02/10 Javascript
详解vue-router 2.0 常用基础知识点之导航钩子
2017/05/10 Javascript
js实现微信/QQ直接跳转到支付宝APP打开口令领红包功能
2018/01/09 Javascript
详解angular路由高亮之RouterLinkActive
2018/04/28 Javascript
微信小程序位置授权处理方法
2019/06/13 Javascript
微信小程序之下拉列表实现方法解析(附完整源码)
2019/08/23 Javascript
Vue的双向数据绑定实现原理解析
2020/02/17 Javascript
[52:03]Secret vs VG 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
名片管理系统python版
2018/01/11 Python
在Pycharm中自动添加时间日期作者等信息的方法
2019/01/16 Python
详解python pandas 分组统计的方法
2019/07/30 Python
Python 将json序列化后的字符串转换成字典(推荐)
2020/01/06 Python
10分钟理解CSS3 FlexBox弹性布局
2018/12/20 HTML / CSS
营业员个人总结的自我评价
2013/10/25 职场文书
学期自我鉴定
2013/11/04 职场文书
聘用意向书范本
2014/04/01 职场文书
村级换届选举方案
2014/05/10 职场文书
离婚协议书怎么写
2014/09/12 职场文书
房产公证委托书范本
2014/09/20 职场文书
2014党员批评和自我批评思想汇报
2014/09/21 职场文书
升职感谢领导的话语及升职感谢信
2019/06/24 职场文书
jQuery ajax - getScript() 方法和getJSON方法
2021/05/14 jQuery