TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用metaclass实现Singleton模式的方法
May 05 Python
独特的python循环语句
Nov 20 Python
老生常谈python的私有公有属性(必看篇)
Jun 09 Python
Python机器学习算法之k均值聚类(k-means)
Feb 23 Python
Python selenium抓取微博内容的示例代码
May 17 Python
django框架之cookie/session的使用示例(小结)
Oct 15 Python
在Pycharm中调试Django项目程序的操作方法
Jul 17 Python
python 叠加等边三角形的绘制的实现
Aug 14 Python
python如何判断IP地址合法性
Apr 05 Python
pycharm配置python 设置pip安装源为豆瓣源
Feb 05 Python
Django实现在线无水印抖音视频下载(附源码及地址)
May 06 Python
python之django路由和视图案例教程
Jul 26 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
php使用cookie实现记住登录状态
2015/04/27 PHP
javascript offsetX与layerX区别
2010/03/12 Javascript
JQuery入门——事件切换之toggle()方法应用介绍
2013/02/05 Javascript
JS获取鼠标坐标的实例方法
2013/07/18 Javascript
返回页面顶部top按钮通过锚点实现(自写)
2013/08/30 Javascript
jQuery.fn和jQuery.prototype区别介绍
2013/10/05 Javascript
JSON中双引号的轮回使用过程中一定要小心
2014/03/05 Javascript
jQuery中wrapInner()方法用法实例
2015/01/16 Javascript
jQuery实现简单的文件上传进度条效果
2020/03/26 Javascript
Angularjs CURD 详解及实例代码
2016/09/14 Javascript
JavaScript 监控微信浏览器且自带返回按钮时间
2016/11/27 Javascript
nodejs连接mysql数据库简单封装示例-mysql模块
2017/04/10 NodeJs
基于JavaScript表单脚本(详解)
2017/10/18 Javascript
vue.js中引入vuex储存接口数据及调用的详细流程
2017/12/14 Javascript
JavaScript 五大常见函数
2018/03/23 Javascript
VueJs组件之父子通讯的方式
2018/05/06 Javascript
jquery.onoff实现简单的开关按钮功能(推荐)
2018/05/24 jQuery
angular ng-model 无法获取值的处理方法
2018/10/02 Javascript
react 应用多入口配置及实践总结
2018/10/17 Javascript
深入理解Python中的内置常量
2017/05/20 Python
Python实现的手机号归属地相关信息查询功能示例
2017/06/08 Python
Django migrations 默认目录修改的方法教程
2018/09/28 Python
django orm 通过related_name反向查询的方法
2018/12/15 Python
Python通过TensorFlow卷积神经网络实现猫狗识别
2019/03/14 Python
django框架使用方法详解
2019/07/18 Python
django实现HttpResponse返回json数据为中文
2020/03/27 Python
Python利用Pillow(PIL)库实现验证码图片的全过程
2020/10/04 Python
求职简历自我评价范例
2014/03/12 职场文书
建筑管理专业求职信
2014/07/28 职场文书
单位介绍信格式
2015/01/31 职场文书
爱牙日宣传活动总结
2015/02/05 职场文书
2015年个人审计工作总结
2015/04/07 职场文书
2016保送生自荐信范文
2016/01/29 职场文书
检讨书范文
2019/04/16 职场文书
MySQL注入基础练习
2021/05/30 MySQL
mysql幻读详解实例以及解决办法
2022/06/16 MySQL