TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之关于类的初步认识
Oct 11 Python
Python之os操作方法(详解)
Jun 15 Python
python 获取指定文件夹下所有文件名称并写入列表的实例
Apr 23 Python
python 实现在txt指定行追加文本的方法
Apr 29 Python
记一次python 内存泄漏问题及解决过程
Nov 29 Python
Django中自定义admin Xadmin的实现代码
Aug 09 Python
python flask搭建web应用教程
Nov 19 Python
python多线程实现代码(模拟银行服务操作流程)
Jan 13 Python
Python virtualenv虚拟环境实现过程解析
Apr 18 Python
Python如何使用神经网络进行简单文本分类
Feb 25 Python
python3判断IP地址的方法
Mar 04 Python
Python 中的Sympy详细使用
Aug 07 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
浅析PHP绘图技术
2013/07/03 PHP
PHP时间格式控制符对照表分享
2013/07/23 PHP
php进程间通讯实例分析
2016/07/11 PHP
smarty模板的使用方法实例分析
2019/09/18 PHP
解决javascript:window.close()在chrome,Firefox下失效的问题
2013/05/07 Javascript
JavaScript实现数字数组正序排列的方法
2015/04/06 Javascript
javascript中call apply 的应用场景
2015/04/16 Javascript
多种JQuery循环滚动文字图片效果代码
2020/06/23 Javascript
正则表达式(语法篇推荐)
2016/06/24 Javascript
javascript超过容器后显示省略号效果的方法(兼容一行或者多行)
2016/07/14 Javascript
JS编写函数实现对身份证号码最后一位的验证功能
2016/12/29 Javascript
vue子组件使用自定义事件向父组件传递数据
2017/05/27 Javascript
JavaScript实用代码小技巧
2018/08/23 Javascript
layui时间控件选择时间范围的实现方法
2019/09/28 Javascript
js中script的上下放置区别,Dom的增删改创建操作实例分析
2019/12/16 Javascript
JS实现导航栏楼层特效
2020/01/01 Javascript
[01:32]寻找你心中的那团火 DOTA2 TI9火焰传递活动今日开启
2019/05/16 DOTA
王纯业的Python学习笔记 下载
2007/02/10 Python
python求列表交集的方法汇总
2014/11/10 Python
Python自动登录126邮箱的方法
2015/07/10 Python
使用Python下载歌词并嵌入歌曲文件中的实现代码
2015/11/13 Python
python代码 输入数字使其反向输出的方法
2018/12/22 Python
python使用SQLAlchemy操作MySQL
2020/01/02 Python
Pytorch训练过程出现nan的解决方式
2020/01/02 Python
AmazeUI的下载配置与Helloworld的实现
2020/08/19 HTML / CSS
日本非常有名的内衣丝袜品牌:GUNZE
2017/01/06 全球购物
DOM和JQuery对象有什么区别
2016/11/11 面试题
大学生学习生活的自我评价
2013/11/01 职场文书
教育学专业毕业生的自我评价
2013/11/21 职场文书
医院护士的求职信范文
2013/12/26 职场文书
电厂厂长岗位职责
2014/01/02 职场文书
工程质量月活动方案
2014/02/19 职场文书
硕士生找工作求职信
2014/07/05 职场文书
导游词之广西漓江
2019/11/02 职场文书
python爬虫selenium模块详解
2021/03/30 Python
解决IDEA翻译插件Translation报错更新TTK失败不能使用
2022/04/24 Python