TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 用Redis简单实现分布式爬虫的方法
Nov 23 Python
python在ubuntu中的几种安装方法(小结)
Dec 08 Python
Python获取指定文件夹下的文件名的方法
Feb 06 Python
对Python中Iterator和Iterable的区别详解
Oct 18 Python
在Python中定义一个常量的方法
Nov 10 Python
在pycharm中使用git版本管理以及同步github的方法
Jan 16 Python
django配置连接数据库及原生sql语句的使用方法
Mar 03 Python
python模拟菜刀反弹shell绕过限制【推荐】
Jun 25 Python
Python编写带选项的命令行程序方法
Aug 13 Python
python设置代理和添加镜像源的方法
Feb 14 Python
利用jupyter网页版本进行python函数查询方式
Apr 14 Python
python selenium xpath定位操作
Sep 01 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
PHP 的 __FILE__ 常量
2007/01/15 PHP
PHP数据库调用类调用实例(详细注释)
2012/07/12 PHP
Php中用PDO查询Mysql来避免SQL注入风险的方法
2013/04/25 PHP
简单了解将WordPress中的工具栏移到底部的小技巧
2015/12/31 PHP
javascript奇异的arguments分析
2010/10/20 Javascript
jquery.fileEveryWhere.js 一个跨浏览器的file显示插件
2011/10/24 Javascript
2012年开发人员的16款新鲜的jquery插件体验分享
2012/12/28 Javascript
JavaScript的模块化:封装(闭包),继承(原型) 介绍
2013/07/22 Javascript
JSON传递bool类型数据的处理方式介绍
2013/09/18 Javascript
js取消单选按钮选中示例代码
2013/11/14 Javascript
jquery实现显示已选用户
2014/07/21 Javascript
js的touch事件的实际引用
2014/10/13 Javascript
网站基于flash实现的Banner图切换效果代码
2014/10/14 Javascript
js进行表单验证实例分析
2015/02/10 Javascript
javascript中checkbox使用方法实例演示
2015/11/19 Javascript
js实现添加可信站点、修改activex安全设置,禁用弹出窗口阻止程序
2016/08/17 Javascript
微信小程序手机号码验证功能的实例代码
2018/08/28 Javascript
详解vue-cli 3.0 build包太大导致首屏过长的解决方案
2018/11/10 Javascript
微信小程序自定义组件传值 页面和组件相互传数据操作示例
2019/05/05 Javascript
JS中使用react-tooltip插件实现鼠标悬浮显示框
2019/05/15 Javascript
vue2路由基本用法实例分析
2020/03/06 Javascript
深入理解python中的浅拷贝和深拷贝
2016/05/30 Python
Python编程中NotImplementedError的使用方法
2018/04/21 Python
Python+OpenCv制作证件图片生成器的操作方法
2019/08/21 Python
Python获取统计自己的qq群成员信息的方法
2019/11/15 Python
python通过移动端访问查看电脑界面
2020/01/06 Python
Python sql注入 过滤字符串的非法字符实例
2020/04/03 Python
python文件读取失败怎么处理
2020/06/23 Python
pycharm实现猜数游戏
2020/12/07 Python
Wedgwood美国官网:英国骨瓷,精美礼品及家居装饰
2018/02/17 全球购物
CHARLES & KEITH澳大利亚官网:新加坡时尚品牌
2019/01/22 全球购物
系统管理员的职责包括那些?管理的对象是什么?
2013/01/18 面试题
三年级科学教学反思
2014/01/29 职场文书
爱心捐助倡议书
2014/05/19 职场文书
Kubernetes关键组件与结构组成介绍
2022/03/31 Servers
如何在Python中妥善使用进度条详解
2022/04/05 Python