TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现监控linux性能及进程消耗性能的方法
Jul 25 Python
python开发之thread线程基础实例入门
Nov 11 Python
python中map()函数的使用方法示例
Sep 29 Python
python实现图书管理系统
Mar 12 Python
Python爬虫框架Scrapy常用命令总结
Jul 26 Python
Python通用循环的构造方法实例分析
Dec 19 Python
python实现合并两个排序的链表
Mar 03 Python
python设置环境变量的作用和实例
Jul 09 Python
Python符号计算之实现函数极限的方法
Jul 15 Python
解析Tensorflow之MNIST的使用
Jun 30 Python
pytorch 实现多个Dataloader同时训练
May 29 Python
Django+Celery实现定时任务的示例
Jun 23 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
风格模板初级不完全修改教程
2006/10/09 PHP
基于mysql的论坛(5)
2006/10/09 PHP
用header 发送cookie的php代码
2007/03/16 PHP
用PHP为SHOPEX增加日志功能代码
2010/07/02 PHP
php中关于普通表单多文件上传的处理方法
2011/03/25 PHP
php eval函数用法总结
2012/10/31 PHP
项目中应用Redis+Php的场景
2016/05/22 PHP
kindeditor 加入七牛云上传的实例讲解
2017/11/12 PHP
用js生产批量批处理执行命令
2008/07/28 Javascript
JQuery 网站换肤功能实现代码
2009/11/02 Javascript
IE的有条件注释判定IE版本详解(附实例代码)
2012/01/04 Javascript
setTimeout的延时为0时多个浏览器的区别
2012/05/23 Javascript
jQuery表单验证插件解析(推荐)
2016/07/21 Javascript
AngularJs 国际化(I18n/L10n)详解
2016/09/01 Javascript
js制作支付倒计时页面
2016/10/21 Javascript
js仿淘宝评价评分功能
2017/02/28 Javascript
Nodejs进阶之服务端字符编解码和乱码处理
2017/09/04 NodeJs
微信小程序实现默认第一个选中变色效果
2018/07/17 Javascript
vue 动态绑定背景图片的方法
2018/08/10 Javascript
vue自定义指令之面板拖拽的实现
2019/04/14 Javascript
微信小程序 搜索框组件代码实例
2019/09/06 Javascript
JS面向对象实现飞机大战
2020/08/26 Javascript
Python urlencode和unquote函数使用实例解析
2020/03/31 Python
python3安装OCR识别库tesserocr过程图解
2020/04/02 Python
英国最大的体育&时尚零售公司:JD Sports
2017/12/13 全球购物
Roxy俄罗斯官方网站:冲浪和滑雪板的一切
2020/06/20 全球购物
师范学院教师自荐书
2014/01/31 职场文书
市场营销专业应届生自荐信
2014/06/19 职场文书
金融与证券专业求职信
2014/06/22 职场文书
2014年远程教育工作总结
2014/12/09 职场文书
行政文员岗位职责
2015/02/04 职场文书
药店收银员岗位职责
2015/04/07 职场文书
2015年机关党建工作总结
2015/05/22 职场文书
2015年教师节主持词
2015/07/03 职场文书
linux下导入、导出mysql数据库命令的实现方法
2021/05/26 MySQL
MySQL系列之十三 MySQL的复制
2021/07/02 MySQL