TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python分析nignx访问日志脚本分享
Feb 26 Python
python实现定时同步本机与北京时间的方法
Mar 24 Python
Python实现保证只能运行一个脚本实例
Jun 24 Python
Python实现抢购IPhone手机
Feb 07 Python
[原创]windows下Anaconda的安装与配置正解(Anaconda入门教程)
Apr 05 Python
Python将DataFrame的某一列作为index的方法
Apr 08 Python
python3.7 sys模块的具体使用
Jul 22 Python
详解python 利用echarts画地图(热力图)(世界地图,省市地图,区县地图)
Aug 06 Python
python如何使用Redis构建分布式锁
Jan 16 Python
python模式 工厂模式原理及实例详解
Feb 11 Python
详解Python常用的魔法方法
Jun 03 Python
Python使用Opencv打开笔记本电脑摄像头报错解问题及解决
Jun 21 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
php 无限级缓存的类的扩展
2009/03/16 PHP
php session处理的定制
2009/03/16 PHP
浅析SVN常见问题及解决方法
2013/06/21 PHP
php编写简单的文章发布程序
2015/06/18 PHP
PHP中ID设置自增后不连续的原因分析及解决办法
2016/08/21 PHP
创建一个复制UBB软件信息的链接或按钮的js代码
2008/01/06 Javascript
jquery实现弹出层完美居中效果
2014/03/03 Javascript
JQuery实现的购物车功能(可以减少或者添加商品并自动计算价格)
2015/01/13 Javascript
JS上传图片前实现图片预览效果的方法
2015/03/02 Javascript
JavaScript动态添加列的方法
2015/03/25 Javascript
JS实现回到页面顶部动画效果的简单实例
2016/05/24 Javascript
浅谈JS正则表达式的RegExp对象和括号的使用
2016/07/28 Javascript
Vue.js Ajax动态参数与列表显示实现方法
2016/10/20 Javascript
js中的事件委托或是事件代理使用详解
2017/06/23 Javascript
分析JS中this引发的bug
2017/12/12 Javascript
利用JQUERY实现多个AJAX请求等待的实例
2017/12/14 jQuery
利用node实现一个批量重命名文件的函数
2017/12/21 Javascript
原生JS实现多个小球碰撞反弹效果示例
2018/01/31 Javascript
javaScript实现鼠标在文字上悬浮时弹出悬浮层效果
2020/04/12 Javascript
js实现input密码框显示/隐藏功能
2020/09/10 Javascript
Vue 实时监听窗口变化 windowresize的两种方法
2018/11/06 Javascript
使用js实现单链解决前端队列问题的方法
2020/02/03 Javascript
node.js 基于 STMP 协议和 EWS 协议发送邮件
2021/02/14 Javascript
python getopt详解及简单实例
2016/12/30 Python
关于Django ForeignKey 反向查询中filter和_set的效率对比详解
2018/12/15 Python
使用html2canvas将页面转成图并使用用canvas2image下载
2019/04/04 HTML / CSS
联想澳大利亚官网:Lenovo Australia
2018/01/18 全球购物
英国文具、办公用品和科技商店:Ryman
2018/09/27 全球购物
Coccinelle官网:意大利的著名皮具品牌
2019/05/15 全球购物
阿里巴巴英国:Alibaba英国
2019/12/11 全球购物
为数据库创建索引都需要注意些什么
2012/07/17 面试题
个人安全生产承诺书
2014/05/22 职场文书
质量安全标语
2014/06/07 职场文书
财务经理岗位职责范本
2015/04/08 职场文书
银行培训心得体会范文
2016/01/09 职场文书
Win11局域网共享权限在哪里设置? Win11高级共享的设置技巧
2022/04/05 数码科技