TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取京东价格分析京东商品价格走势
Jan 09 Python
python中global与nonlocal比较
Nov 21 Python
浅要分析Python程序与C程序的结合使用
Apr 07 Python
Python下线程之间的共享和释放示例
May 04 Python
python实现计算倒数的方法
Jul 11 Python
Go语言基于Socket编写服务器端与客户端通信的实例
Feb 19 Python
python使用正则表达式匹配字符串开头并打印示例
Jan 11 Python
利用python获取当前日期前后N天或N月日期的方法示例
Jul 30 Python
python按综合、销量排序抓取100页的淘宝商品列表信息
Feb 24 Python
Django2.1集成xadmin管理后台所遇到的错误集锦(填坑)
Dec 20 Python
详解Python Qt的窗体开发的基本操作
Jul 14 Python
Django DRF认证组件流程实现原理详解
Aug 17 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
php项目打包方法
2008/02/18 PHP
Zend framework处理一个http请求的流程分析
2010/02/08 PHP
PHP使用内置dir类实现目录遍历删除
2015/03/31 PHP
thinkPHP下的widget扩展用法实例分析
2015/12/26 PHP
浅析PHP中的 inet_pton 网络函数
2019/12/16 PHP
从零开始学习jQuery (三) 管理jQuery包装集
2011/02/23 Javascript
JavaScript 用Node.js写Shell脚本[译]
2012/09/20 Javascript
jQuery的控件及事件(输入控件及回车事件)使用示例
2013/07/25 Javascript
js中事件的处理与浏览器对象示例介绍
2013/11/29 Javascript
javascript框架设计之浏览器的嗅探和特征侦测
2015/06/23 Javascript
JS排序方法(sort,bubble,select,insert)代码汇总
2016/01/30 Javascript
JS+CSS实现的漂亮渐变背景特效代码(6个渐变效果)
2016/03/25 Javascript
浅谈Sublime Text 3运行JavaScript控制台
2016/06/06 Javascript
深入理解逻辑表达式的用法 与或非的用法
2016/06/06 Javascript
javascript滚轮控制模拟滚动条
2016/10/19 Javascript
webpack+vue中使用别名路径引用静态图片地址
2017/11/20 Javascript
es6新特性之 class 基本用法解析
2018/05/05 Javascript
vue项目打包部署到服务器的方法示例
2018/08/27 Javascript
微信小程序左滑删除功能开发案例详解
2018/11/12 Javascript
nodejs实现百度舆情接口应用示例
2020/02/07 NodeJs
Windows上安装tensorflow  详细教程(图文详解)
2020/02/04 Python
python使用openpyxl操作excel的方法步骤
2020/05/28 Python
Pytorch实现将模型的所有参数的梯度清0
2020/06/24 Python
Python tkinter制作单机五子棋游戏
2020/09/14 Python
HTML5之SVG 2D入门7—SVG元素的重用与引用
2013/01/30 HTML / CSS
Puritan’s Pride(普丽普莱)官方网站:美国最大最全的保健品公司之一
2016/10/23 全球购物
小学教师办公室制度
2014/02/03 职场文书
《老王》教学反思
2014/02/23 职场文书
某某同志考察材料
2014/05/28 职场文书
消防宣传口号
2014/06/16 职场文书
学风建设演讲稿
2014/09/12 职场文书
2014年护理工作总结范文
2014/11/14 职场文书
2015年健康教育工作总结
2015/04/10 职场文书
公司财务管理制度
2015/08/04 职场文书
Golang连接并操作MySQL
2022/04/14 MySQL
uniapp开发打包多端应用完整方法指南
2022/12/24 Javascript