TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于multiprocessing的多进程创建方法
Jun 04 Python
关于Django外键赋值问题详解
Aug 13 Python
利用pyinstaller将py文件打包为exe的方法
May 14 Python
Tensorflow卷积神经网络实例进阶
May 24 Python
python3连接MySQL数据库实例详解
May 24 Python
python递归函数绘制分形树的方法
Jun 22 Python
python 返回列表中某个值的索引方法
Nov 07 Python
解决python3中cv2读取中文路径的问题
Dec 05 Python
对Python w和w+权限的区别详解
Jan 23 Python
详解Python Opencv和PIL读取图像文件的差别
Dec 27 Python
python 判断txt每行内容中是否包含子串并重新写入保存的实例
Mar 12 Python
Python函数调用追踪实现代码
Nov 27 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
php 获取完整url地址
2008/12/20 PHP
php 随机生成10位字符代码
2009/03/26 PHP
PHP判断远程url是否有效的几种方法小结
2011/10/08 PHP
javascript 放大镜 v1.0 基于Yui2 实现的放大镜效果
2010/03/08 Javascript
JavaScript中获取未知对象属性的代码
2011/04/27 Javascript
jQuery+css+html实现页面遮罩弹出框
2013/03/21 Javascript
原生js实现改变随意改变div属性style的名称和值的结果
2013/09/26 Javascript
JS控制网页动态生成任意行列数表格的方法
2015/03/09 Javascript
javascript实现俄罗斯方块游戏的思路和方法
2015/04/27 Javascript
js控制div弹出层实现方法
2015/05/11 Javascript
BootStrap glyphicon图标无法显示的解决方法
2016/09/06 Javascript
JS 获取HTML标签内的子节点的方法
2016/09/21 Javascript
bootstrap输入框组使用方法
2017/02/07 Javascript
vue中的非父子间的通讯问题简单的实例代码
2017/07/19 Javascript
使用Vuex实现一个笔记应用的方法
2018/03/13 Javascript
Node.js如何优雅的封装一个实用函数的npm包的方法
2019/04/29 Javascript
了解javascript中let和var及const关键字的区别
2019/05/24 Javascript
JavaScript获取当前url路径过程解析
2019/12/27 Javascript
对于Python异常处理慎用“except:pass”建议
2015/04/02 Python
深入浅析Python中list的复制及深拷贝与浅拷贝
2018/09/03 Python
基于tensorflow for循环 while循环案例
2020/06/30 Python
Django框架实现在线考试系统的示例代码
2020/11/30 Python
HTML5网页录音和上传到服务器支持PC、Android,支持IOS微信功能
2019/04/26 HTML / CSS
程序运行正确, 但退出时却"core dump"了,怎么回事
2014/02/19 面试题
大学本科毕业生求职简历的自我评价
2013/10/09 职场文书
微型企业创业投资计划书
2014/01/10 职场文书
商场活动策划方案
2014/01/24 职场文书
2014年党的群众路线活动个人整改措施
2014/10/28 职场文书
领导干部群众路线对照检查材料
2014/11/05 职场文书
保密工作整改情况汇报
2014/11/06 职场文书
英文感谢信范文
2015/01/21 职场文书
好员工观后感
2015/06/17 职场文书
2016春季幼儿园开学寄语
2015/12/03 职场文书
三好学生竞选稿范文
2019/08/21 职场文书
写一个Python脚本自动爬取Bilibili小视频
2021/04/24 Python
JS前端宏任务微任务及Event Loop使用详解
2022/07/23 Javascript