TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python实现的扫雷游戏实例代码
Aug 01 Python
详解Python中DOM方法的动态性
Apr 11 Python
简单了解python模块概念
Jan 11 Python
python模块之paramiko实例代码
Jan 31 Python
查看Django和flask版本的方法
May 14 Python
pycharm 将django中多个app放到同个文件夹apps的处理方法
May 30 Python
python生成以及打开json、csv和txt文件的实例
Nov 16 Python
Python pygame绘制文字制作滚动文字过程解析
Dec 12 Python
python pyqtgraph 保存图片到本地的实例
Mar 14 Python
在python中修改.properties文件的操作
Apr 08 Python
python中os.remove()用法及注意事项
Jan 31 Python
pytorch 中autograd.grad()函数的用法说明
May 12 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
《超神学院》霸气归来, 天使彦上演维多利亚的秘密
2020/03/02 国漫
改造一台复古桌面收音机
2021/03/02 无线电
php 5.3.5安装memcache注意事项小结
2011/04/12 PHP
php中将时间差转换为字符串提示的实现代码
2011/08/08 PHP
Thinkphp 在api开发中异常返回依然是html的解决方式
2019/10/16 PHP
Yii使用EasyWechat实现小程序获取用户的openID的方法
2020/04/29 PHP
通过ifame指向的页面高度调整iframe的高度
2006/10/05 Javascript
JavaScript匿名函数之模仿块级作用域
2015/12/12 Javascript
jQuery+Ajax实现无刷新操作
2016/01/04 Javascript
jquery心形点赞关注效果的简单实现
2016/11/14 Javascript
bootstrap日历插件datetimepicker使用方法
2016/12/14 Javascript
老生常谈js中的MVC
2017/07/25 Javascript
JavaScript定时器setTimeout()和setInterval()详解
2017/08/18 Javascript
jQuery实现评论模块
2020/08/19 jQuery
vue实现简单加法计算器
2020/10/22 Javascript
[10:53]2018DOTA2国际邀请赛寻真——EG
2018/08/11 DOTA
举例讲解Django中数据模型访问外键值的方法
2015/07/21 Python
python多线程方式执行多个bat代码
2016/06/07 Python
win系统下为Python3.5安装flask-mongoengine 库
2016/12/20 Python
Python简单实现Base64编码和解码的方法
2017/04/29 Python
解决Django的request.POST获取不到内容的问题
2018/05/28 Python
详解django.contirb.auth-认证
2018/07/16 Python
Python面向对象之反射/自省机制实例分析
2018/08/24 Python
Python判断一个三位数是否为水仙花数的示例
2018/11/13 Python
opencv实现图片模糊和锐化操作
2018/11/19 Python
python3定位并识别图片验证码实现自动登录功能
2021/01/29 Python
html5新特性与用法大全
2018/09/13 HTML / CSS
HTML5图片层叠的实现示例
2020/07/07 HTML / CSS
HTML5中外部浏览器唤起微信分享功能的代码
2020/09/15 HTML / CSS
芭比波朗加拿大官方网站:Bobbi Brown Cosmetics CA
2020/11/05 全球购物
通用C#笔试题附答案
2016/11/26 面试题
2013年保送生自荐信格式
2013/11/20 职场文书
2015年教师师德师风承诺书
2015/04/28 职场文书
redis 查看所有的key方式
2021/05/07 Redis
Python3 类型标注支持操作
2021/06/02 Python
Python编解码问题及文本文件处理方法详解
2021/06/20 Python