TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python进程间通信用法实例
Jun 04 Python
Python下Fabric的简单部署方法
Jul 14 Python
浅谈Python 的枚举 Enum
Jun 12 Python
Python实现购物车程序
Apr 16 Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 Python
详解Python下ftp上传文件linux服务器
Jun 21 Python
利用python画出折线图
Jul 26 Python
关于PyTorch 自动求导机制详解
Aug 18 Python
如何解决django-celery启动后迅速关闭
Oct 16 Python
Python函数生成器原理及使用详解
Mar 12 Python
python中not、and和or的优先级与详细用法介绍
Nov 03 Python
python 如何将两个实数矩阵合并为一个复数矩阵
May 19 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
Zend Studio 实用快捷键一览表(精心整理)
2013/08/10 PHP
解决AJAX中跨域访问出现'没有权限'的错误
2008/08/20 Javascript
jquery获得下拉框值的代码
2011/08/13 Javascript
基于jQuery的360图片展示实现代码
2012/06/14 Javascript
js 延迟加载 改变JS的位置加快网页加载速度
2012/12/11 Javascript
jQuery Mobile页面跳转后未加载外部JS原因分析及解决
2013/03/18 Javascript
js获取form的方法
2015/05/06 Javascript
angularjs学习笔记之双向数据绑定
2015/09/26 Javascript
javascript中去除数组重复元素的实现方法【实例】
2016/04/12 Javascript
解决JS组件bootstrap table分页实现过程中遇到的问题
2016/04/21 Javascript
jQuery图片轮播插件——前端开发必看
2016/05/31 Javascript
微信小程序 http请求的session管理
2017/06/07 Javascript
js 图片转base64的方式(两种)
2018/04/24 Javascript
jQuery实现的监听导航滚动置顶状态功能示例
2018/07/23 jQuery
Vuex的初探与实战小结
2018/11/26 Javascript
基于Vue实现平滑过渡的拖拽排序功能
2019/06/12 Javascript
解决nuxt 自定义全局方法,全局属性,全局变量的问题
2020/11/05 Javascript
跟老齐学Python之变量和参数
2014/10/10 Python
Python通过select实现异步IO的方法
2015/06/04 Python
解析Mac OS下部署Pyhton的Django框架项目的过程
2016/05/03 Python
python高效过滤出文件夹下指定文件名结尾的文件实例
2018/10/21 Python
解决yum对python依赖版本问题
2019/07/05 Python
解决Python设置函数调用超时,进程卡住的问题
2019/08/08 Python
python输出决策树图形的例子
2019/08/09 Python
Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取
2020/06/30 Python
python 数据类型强制转换的总结
2021/01/25 Python
css3 按钮 利用css3实现超酷下载按钮
2013/03/18 HTML / CSS
土耳其时尚购物网站:Morhipo
2017/09/04 全球购物
介绍一下Java的安全机制
2012/06/28 面试题
小学生班会演讲稿
2014/01/09 职场文书
优秀党员主要事迹
2014/01/19 职场文书
军校大学生个人的自我评价
2014/02/17 职场文书
希特勒的演讲稿
2014/05/23 职场文书
先进典型发言材料
2014/12/30 职场文书
个人求职意向书
2015/05/11 职场文书
2015年控辍保学工作总结
2015/05/18 职场文书