TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则分组的应用
Nov 10 Python
解决nohup重定向python输出到文件不成功的问题
May 11 Python
Python中偏函数用法示例
Jun 07 Python
解决python 自动安装缺少模块的问题
Oct 22 Python
python3文件复制、延迟文件复制任务的实现方法
Sep 02 Python
python实现的登录与提交表单数据功能示例
Sep 25 Python
pytorch 实现打印模型的参数值
Dec 30 Python
Python操作Excel把数据分给sheet
May 20 Python
基于python实现可视化生成二维码工具
Jul 08 Python
django数据模型中null和blank的区别说明
Sep 02 Python
OpenCV实现机器人对物体进行移动跟随的方法实例
Nov 09 Python
Python实现Matplotlib,Seaborn动态数据图
May 06 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
微信小程序 消息推送php服务器验证实例详解
2017/03/30 PHP
自适应图片大小的弹出窗口
2006/07/27 Javascript
Javascript优化技巧(文件瘦身篇)
2008/01/28 Javascript
五段实用的js高级技巧
2011/12/20 Javascript
使用jQuery清空file文件域的解决方案
2013/04/12 Javascript
解析Javascript中大括号“{}”的多义性
2013/12/02 Javascript
通过正则表达式实现表单验证是否为中文
2014/02/18 Javascript
详解JavaScript ES6中的模板字符串
2015/07/28 Javascript
JS实现常见的TAB、弹出层效果(TAB标签,斑马线,遮罩层等)
2015/10/08 Javascript
常常会用到的截取字符串substr()、substring()、slice()方法详解
2015/12/16 Javascript
js中flexible.js实现淘宝弹性布局方案
2020/06/23 Javascript
jQuery基于正则表达式的表单验证功能示例
2017/01/21 Javascript
使用JavaScript判断用户输入的是否为正整数(两种方法)
2017/02/05 Javascript
基于JavaScript实现拖动滑块效果
2017/02/16 Javascript
vue 多入口文件搭建 vue多页面搭建的实例讲解
2018/03/12 Javascript
Angular HMR(热模块替换)功能实现方法
2018/04/04 Javascript
使用Vue动态生成form表单的实例代码
2018/04/26 Javascript
Vue props 单向数据流的实现
2018/11/06 Javascript
Bootstrap Paginator+PageHelper实现分页效果
2018/12/29 Javascript
使用纯前端JavaScript实现Excel导入导出方法过程详解
2020/08/07 Javascript
[08:44]和酒神一起战斗 DOTA2教你做大人
2014/03/27 DOTA
Python中MYSQLdb出现乱码的解决方法
2014/10/11 Python
详解Python3中的Sequence type的使用
2015/08/01 Python
用Python3创建httpServer的简单方法
2018/06/04 Python
python中update的基本使用方法详解
2019/07/17 Python
pytorch-RNN进行回归曲线预测方式
2020/01/14 Python
中国最大的潮流商品购物网站:YOHO!BUY有货
2017/01/07 全球购物
导游实习生自荐书
2014/01/28 职场文书
班主任工作经验材料
2014/02/02 职场文书
《池塘边的叫声》教学反思
2014/04/12 职场文书
电子信息工程专业自荐书
2014/06/24 职场文书
师德师风自查材料
2014/10/14 职场文书
实习计划书范文
2015/01/16 职场文书
劳动合同变更协议书范本
2019/04/18 职场文书
Spring Boot mybatis-config 和 log4j 输出sql 日志的方式
2021/07/26 Java/Android
Tomcat项目启动失败的原因和解决办法
2022/04/20 Servers