TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现全局变量的两个解决方法
Jul 03 Python
使用Python脚本将绝对url替换为相对url的教程
Apr 24 Python
Python EOL while scanning string literal问题解决方法
Sep 18 Python
python开发之list操作实例分析
Feb 22 Python
python类的继承实例详解
Mar 30 Python
浅谈python for循环的巧妙运用(迭代、列表生成式)
Sep 26 Python
python实现聚类算法原理
Feb 12 Python
python学习基础之循环import及import过程
Apr 22 Python
python通过zabbix api获取主机
Sep 17 Python
Python实现查找字符串数组最长公共前缀示例
Mar 27 Python
python ctypes库2_指定参数类型和返回类型详解
Nov 19 Python
Pandas读取csv时如何设置列名
Jun 02 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
虚拟主机中对PHP的特殊设置
2006/10/09 PHP
PHP中批量生成静态html(命令行下运行PHP)
2014/04/19 PHP
CodeIgniter配置之SESSION用法实例分析
2016/01/19 PHP
php中实现字符串翻转的方法
2017/02/22 PHP
使用XHProf查找PHP性能瓶颈的实例
2017/12/13 PHP
PHP8.0新功能之Match表达式的使用
2020/07/19 PHP
Jquery乱码的一次解决过程 图解教程
2010/02/20 Javascript
js如何获取兄弟、父类等节点
2014/01/06 Javascript
JS(JQuery)操作Array的相关方法介绍
2014/02/11 Javascript
基于JavaScript短信验证码如何实现
2016/01/24 Javascript
javascript 四十条常用技巧大全
2016/09/09 Javascript
JavaScript获取URL中参数querystring的方法详解
2016/10/11 Javascript
Validform表单验证总结篇
2016/10/31 Javascript
angular4模块中给标签添加背景图的实现方法
2017/09/15 Javascript
JS实现的A*寻路算法详解
2018/12/14 Javascript
[00:12]DAC2018 天才少年转战三号位,他的SOLO是否仍如昔日般强大?
2018/04/06 DOTA
跟老齐学Python之开始真正编程
2014/09/12 Python
简单介绍Python中的几种数据类型
2016/01/02 Python
Python实现在线音乐播放器
2017/03/03 Python
Python列表(List)知识点总结
2019/02/18 Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
2019/08/19 Python
python hash每次调用结果不同的原因
2019/11/21 Python
Python scrapy增量爬取实例及实现过程解析
2019/12/24 Python
Python3 xml.etree.ElementTree支持的XPath语法详解
2020/03/06 Python
Python基于Tkinter编写crc校验工具
2020/05/06 Python
如何使用python记录室友的抖音在线时间
2020/06/29 Python
两种CSS3伪类选择器详细介绍
2013/12/24 HTML / CSS
前端隐藏出边界内容的实现方法
2016/04/14 HTML / CSS
日本最大的药妆连锁店:Matsukiyo松本清药妆店
2017/11/23 全球购物
Chupi官网:在爱尔兰手工制作的订婚、结婚戒指和精美珠宝
2020/09/28 全球购物
房地产经营管理专业自荐信
2014/09/02 职场文书
2014县委书记党的群众路线教育实践活动对照检查材料思想汇报
2014/09/22 职场文书
办公室务虚会发言材料
2014/10/20 职场文书
2019年年中工作总结讲话稿模板
2019/03/25 职场文书
如何使用Python提取Chrome浏览器保存的密码
2021/06/09 Python
python 进阶学习之python装饰器小结
2021/09/04 Python