TensorFlow——Checkpoint为模型添加检查点的实例


Posted in Python onJanuary 21, 2020

1.检查点

保存模型并不限于在训练模型后,在训练模型之中也需要保存,因为TensorFlow训练模型时难免会出现中断的情况,我们自然希望能够将训练得到的参数保存下来,否则下次又要重新训练。

这种在训练中保存模型,习惯上称之为保存检查点。

2.添加保存点

通过添加检查点,可以生成载入检查点文件,并能够指定生成检查文件的个数,例如使用saver的另一个参数——max_to_keep=1,表明最多只保存一个检查点文件,在保存时使用如下的代码传入迭代次数。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

plt.plot(train_x, train_y, 'r.')
plt.grid(True)
plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


saver = tf.train.Saver(max_to_keep=15)
savedir = "model/"


if __name__ == '__main__':
 with tf.Session() as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})

   saver.save(sess, savedir + "linear.cpkt", global_step=epoch)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())
  saver.restore(sess2, savedir + "linear.cpkt-" + str(load_epoch))
  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

在上述的代码中,我们使用saver.save(sess, savedir + "linear.cpkt", global_step=epoch)将训练的参数传入检查点进行保存,saver = tf.train.Saver(max_to_keep=1)表示只保存一个文件,这样在训练过程中得到的新的模型就会覆盖以前的模型。

cpkt = tf.train.get_checkpoint_state(savedir)
if cpkt and cpkt.model_checkpoint_path:
saver.restore(sess2, cpkt.model_checkpoint_path)

kpt = tf.train.latest_checkpoint(savedir)
saver.restore(sess2, kpt)

上述的两种方法也可以对checkpoint文件进行加载,tf.train.latest_checkpoint(savedir)为加载最后的检查点文件。这种方式,我们可以通过保存指定训练次数的检查点,比如保存5的倍数次保存一下检查点。

3.简便保存检查点

我们还可以用更加简单的方法进行检查点的保存,tf.train.MonitoredTrainingSession()函数,该函数可以直接实现保存载入检查点模型的文件,与前面的方法不同的是,它是按照训练时间来保存检查点的,可以通过指定save_checkpoint_secs参数的具体秒数,设置多久保存一次检查点。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

train_x = np.linspace(-5, 3, 50)
train_y = train_x * 5 + 10 + np.random.random(50) * 10 - 5

# plt.plot(train_x, train_y, 'r.')
# plt.grid(True)
# plt.show()

tf.reset_default_graph()

X = tf.placeholder(dtype=tf.float32)
Y = tf.placeholder(dtype=tf.float32)

w = tf.Variable(tf.random.truncated_normal([1]), name='Weight')
b = tf.Variable(tf.random.truncated_normal([1]), name='bias')

z = tf.multiply(X, w) + b

cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 30
display_step = 2


global_step = tf.train.get_or_create_global_step()

step = tf.assign_add(global_step, 1)

saver = tf.train.Saver()

savedir = "check-point/"

if __name__ == '__main__':
 with tf.train.MonitoredTrainingSession(checkpoint_dir=savedir + 'linear.cpkt', save_checkpoint_secs=5) as sess:
  sess.run(init)
  loss_list = []
  for epoch in range(training_epochs):
   sess.run(global_step)
   for (x, y) in zip(train_x, train_y):
    sess.run(optimizer, feed_dict={X: x, Y: y})

   if epoch % display_step == 0:
    loss = sess.run(cost, feed_dict={X: x, Y: y})
    loss_list.append(loss)
    print('Iter: ', epoch, ' Loss: ', loss)

   w_, b_ = sess.run([w, b], feed_dict={X: x, Y: y})
   sess.run(step)

  print(" Finished ")
  print("W: ", w_, " b: ", b_, " loss: ", loss)
  plt.plot(train_x, train_x * w_ + b_, 'g-', train_x, train_y, 'r.')
  plt.grid(True)
  plt.show()

 load_epoch = 10

 with tf.Session() as sess2:
  sess2.run(tf.global_variables_initializer())

  # saver.restore(sess2, savedir + 'linear.cpkt-' + str(load_epoch))

  # cpkt = tf.train.get_checkpoint_state(savedir)
  # if cpkt and cpkt.model_checkpoint_path:
  #  saver.restore(sess2, cpkt.model_checkpoint_path)
  #
  kpt = tf.train.latest_checkpoint(savedir + 'linear.cpkt')

  saver.restore(sess2, kpt)

  print(sess2.run([w, b], feed_dict={X: train_x, Y: train_y}))

上述的代码中,我们设置了没训练了5秒中之后,就保存一次检查点,它默认的保存时间间隔是10分钟,这种按照时间的保存模式更适合使用大型数据集训练复杂模型的情况,注意在使用上述的方法时,要定义global_step变量,在训练完一个批次或者一个样本之后,要将其进行加1的操作,否则将会报错。

TensorFlow——Checkpoint为模型添加检查点的实例

以上这篇TensorFlow——Checkpoint为模型添加检查点的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解决字典中的值是列表问题的方法
Mar 04 Python
使用Python发送邮件附件以定时备份MySQL的教程
Apr 25 Python
wxPython中listbox用法实例详解
Jun 01 Python
go和python变量赋值遇到的一个问题
Aug 31 Python
Django 路由系统URLconf的使用
Oct 11 Python
Python3.6.2调用ffmpeg的方法
Jan 10 Python
Django admin model 汉化显示文字的实现方法
Aug 12 Python
利用rest framework搭建Django API过程解析
Aug 31 Python
python 实现在无序数组中找到中位数方法
Mar 03 Python
python+OpenCV实现图像拼接
Mar 05 Python
Python基于stuck实现scoket文件传输
Apr 02 Python
简述python四种分词工具,盘点哪个更好用?
Apr 13 Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
python爬取本站电子书信息并入库的实现代码
Jan 20 #Python
浅谈Tensorflow 动态双向RNN的输出问题
Jan 20 #Python
You might like
动漫定律:眯眯眼都是怪物!这些角色狠话不多~
2020/03/03 日漫
PHP+FastCGI+Nginx配置PHP运行环境
2014/08/07 PHP
前端必学之PHP语法基础
2016/01/01 PHP
PHP使用stream_context_create()模拟POST/GET请求的方法
2016/04/02 PHP
php封装单文件上传到数据库(路径)
2017/10/15 PHP
jQuery 页面载入进度条实现代码
2009/02/08 Javascript
JavaScript 学习小结(适合新手参考)
2009/07/30 Javascript
JQuery处理json与ajax返回JSON实例代码
2014/01/03 Javascript
JavaScript实现下拉列表框数据增加、删除、上下排序的方法
2015/08/11 Javascript
Jquery修改image的src属性,图片不加载问题的解决方法
2016/05/17 Javascript
jQuery 控制文本框自动缩小字体填充
2017/06/16 jQuery
JS中数据结构之栈
2019/01/01 Javascript
layui在form表单页面通过Validform加入简单验证的方法
2019/09/06 Javascript
vue中重定向redirect:‘/index‘,不显示问题、跳转出错的完美解决
2020/09/28 Javascript
JavaScript缓动动画函数的封装方法
2020/11/25 Javascript
python二维列表一维列表的互相转换实例
2018/07/02 Python
利用python实现简易版的贪吃蛇游戏(面向python小白)
2018/12/30 Python
Python超越函数积分运算以及绘图实现代码
2019/11/20 Python
python 服务器运行代码报错ModuleNotFoundError的解决办法
2020/09/16 Python
深入研究HTML5实现图片压缩上传功能
2016/03/25 HTML / CSS
HTML5 canvas基本绘图之填充样式实现
2016/06/27 HTML / CSS
印尼在线购买隐形眼镜网站:Lensza.co.id
2019/04/27 全球购物
Hanky Panky官方网站:内衣和睡衣
2019/07/25 全球购物
俄罗斯女装店:12storeez
2019/10/25 全球购物
询价采购方案
2014/06/09 职场文书
学校工作推荐信范文
2014/07/11 职场文书
2014年酒店工作总结范文
2014/11/17 职场文书
老公婚前保证书
2015/02/28 职场文书
关于运动会的宣传稿
2015/07/23 职场文书
《赵州桥》教学反思
2016/02/17 职场文书
高中语文教材(文学文化常识大全一)
2019/08/13 职场文书
酒店工程部的岗位职责汇总大全
2019/10/23 职场文书
Nginx反爬虫策略,防止UA抓取网站
2021/03/31 Servers
JPA 通过Specification如何实现复杂查询
2021/11/23 Java/Android
Spring Data JPA框架Repository自定义实现
2022/04/28 Java/Android
win10如何开启ahci模式?win10开启ahci模式详细操作教程
2022/07/23 数码科技