TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python发送伪造的arp请求
Jan 09 Python
python顺序的读取文件夹下名称有序的文件方法
Jul 11 Python
Python列表对象实现原理详解
Jul 01 Python
python的debug实用工具 pdb详解
Jul 12 Python
python对常见数据类型的遍历解析
Aug 27 Python
使用pyshp包进行shapefile文件修改的例子
Dec 06 Python
python实现银行实战系统
Feb 26 Python
python实现将中文日期转换为数字日期
Jul 14 Python
Python用来做Web开发的优势有哪些
Aug 05 Python
Python使用正则表达式实现爬虫数据抽取
Aug 17 Python
一文搞懂如何实现Go 超时控制
Mar 30 Python
详细介绍python操作RabbitMq
Apr 12 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
自己前几天写的无限分类类
2007/02/14 PHP
PHP实现股票趋势图和柱形图
2015/02/07 PHP
Javascript实例教程(19) 使用HoTMetal(3)
2006/12/23 Javascript
javascript (用setTimeout而非setInterval)
2011/12/28 Javascript
jquery实现隐藏与显示动画效果/输入框字符动态递减/导航按钮切换
2013/07/01 Javascript
Javascript实现简单二级下拉菜单实例
2014/06/15 Javascript
js限制文本框只能输入整数或者带小数点的数字
2015/04/27 Javascript
JS+CSS实现的拖动分页效果实例
2015/05/11 Javascript
JS建造者模式基本用法实例分析
2015/06/30 Javascript
JavaScript模板引擎用法实例
2015/07/10 Javascript
基于jQuery通过jQuery.form.js插件实现异步上传
2015/12/13 Javascript
js定义类的几种方法(推荐)
2016/06/08 Javascript
AngularJS中的缓存使用
2017/01/11 Javascript
JS中把函数作为另一函数的参数传递方法(总结)
2017/06/28 Javascript
VueJs 将接口用webpack代理到本地的方法
2017/11/27 Javascript
vue2 中二级路由高亮问题及配置方法
2019/06/10 Javascript
React 全自动数据表格组件——BodeGrid的实现思路
2019/06/12 Javascript
Vue内部渲染视图的方法
2019/09/02 Javascript
Python 如何访问外围作用域中的变量
2016/09/11 Python
python 打印直角三角形,等边三角形,菱形,正方形的代码
2017/11/21 Python
python合并同类型excel表格的方法
2018/04/01 Python
Python实现的json文件读取及中文乱码显示问题解决方法
2018/08/06 Python
django query模块
2019/04/20 Python
通过 Django Pagination 实现简单分页功能
2019/11/11 Python
python中Lambda表达式详解
2019/11/20 Python
为什么黑客都用python(123个黑客必备的Python工具)
2020/01/31 Python
python with (as)语句实例详解
2020/02/04 Python
Python pip使用超时问题解决方案
2020/08/03 Python
Python如何批量生成和调用变量
2020/11/21 Python
在 Python 中使用 7zip 备份文件的操作
2020/12/11 Python
html5移动端自适应布局的实现
2020/04/15 HTML / CSS
英语演讲稿范文
2014/01/03 职场文书
拖欠货款起诉状
2015/05/20 职场文书
《坐井观天》教学反思
2016/02/18 职场文书
2019年自助餐厅创业计划书模板
2019/08/22 职场文书
python图片灰度化处理的几种方法
2021/06/23 Python