TensorFlow实现模型断点训练,checkpoint模型载入方式


Posted in Python onMay 26, 2020

深度学习中,模型训练一般都需要很长的时间,由于很多原因,导致模型中断训练,下面介绍继续断点训练的方法。

方法一:载入模型时,不必指定迭代次数,一般默认最新

# 保存模型
saver = tf.train.Saver(max_to_keep=1) # 最多保留最新的模型
 
# 开启会话
with tf.Session() as sess:
 # saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(20000))
 sess.run(tf.global_variables_initializer())
 ckpt = tf.train.get_checkpoint_state('./log/') # 注意此处是checkpoint存在的目录,千万不要写成‘./log'
 if ckpt and ckpt.model_checkpoint_path:
 saver.restore(sess,ckpt.model_checkpoint_path) # 自动恢复model_checkpoint_path保存模型一般是最新
 print("Model restored...")
 else:
 print('No Model')

方法二:载入时,指定想要载入模型的迭代次数

需要到Log文件夹中,查看当前迭代的次数,如下:此时为111000次。

TensorFlow实现模型断点训练,checkpoint模型载入方式

# 保存模型
saver = tf.train.Saver(max_to_keep=1)
# 开启会话
 
with tf.Session() as sess:
 saver.restore(sess, './log/' + "model_savemodel.cpkt-" + str(111000))
 sess.run(tf.global_variables_initializer())

载入模型后,会继续端点处的变量继续训练,那么是否可以减小剩余的需要的迭代次数?

模型断点训练效果展示:

训练到167000次后,载入模型重新训练。设置迭代次数为10000次,(d_step=1000)。原始设置的迭代的次数为1000000,已经训练了167000次。

Model restored...
Iter:0, D_loss:0.5139875411987305, G_loss:2.8023970127105713
Iter:1000, D_loss:0.4400891065597534, G_loss:2.781547784805298
Iter:2000, D_loss:0.5169454216957092, G_loss:2.58009934425354
Iter:3000, D_loss:0.4507023096084595, G_loss:2.584151268005371
Iter:4000, D_loss:0.5746167898178101, G_loss:2.5365757942199707
Iter:5000, D_loss:0.5288565158843994, G_loss:2.426676034927368
Iter:6000, D_loss:0.549595057964325, G_loss:2.820535659790039
Iter:7000, D_loss:0.32620012760162354, G_loss:2.540236473083496
Iter:8000, D_loss:0.4363398551940918, G_loss:2.5880446434020996
Iter:9000, D_loss:0.569464921951294, G_loss:2.5133447647094727
done!

保存的图片仍然从头开始编号,会覆盖掉之前的图片。

TensorFlow实现模型断点训练,checkpoint模型载入方式

以前对应编号的采样图片为:

TensorFlow实现模型断点训练,checkpoint模型载入方式

若有朋友有高见,还请不吝赐教。

补充知识:tensorflow加载训练好的模型及参数(读取checkpoint)

checkpoint 保存路径

model_path下存有包含多个迭代次数的模型

TensorFlow实现模型断点训练,checkpoint模型载入方式

1.获取最新保存的模型

即上图中的model-9400

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=tf.train.latest_checkpoint(model_path)
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

2.获取某个迭代次数的模型

比如上图中的model-9200

import tensorflow as tf

graph=tf.get_default_graph()  # 获取当前图
sess=tf.Session()
sess.run(tf.global_variables_initializer())

checkpoint_file=os.path.join(model_path,'model-9200')
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess,checkpoint_file)

获取变量值

## 得到当前图中所有变量的名称
tensor_name_list=[tensor.name for tensor in graph.as_graph_def().node] 
# 查看所有变量
print(tensor_name_list) 

# 获取input_x和input_y的变量值
input_x = graph.get_operation_by_name("input_x").outputs[0]
input_y = graph.get_operation_by_name("input_y").outputs[0]

以上这篇TensorFlow实现模型断点训练,checkpoint模型载入方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python之ReportLab绘制条形码和二维码的实例
Jan 15 Python
python表格存取的方法
Mar 07 Python
Pandas标记删除重复记录的方法
Apr 08 Python
python中字符串的操作方法大全
Jun 03 Python
python numpy 一维数组转变为多维数组的实例
Jul 02 Python
Python调用服务接口的实例
Jan 03 Python
Python 隐藏输入密码时屏幕回显的实例
Feb 19 Python
python微信聊天机器人改进版(定时或触发抓取天气预报、励志语录等,向好友推送)
Apr 25 Python
python进程和线程用法知识点总结
May 28 Python
在 Jupyter 中重新导入特定的 Python 文件(场景分析)
Oct 27 Python
pymysql的简单封装代码实例
Jan 08 Python
Python Django中的STATIC_URL 设置和使用方式
Mar 27 Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
Python 如何批量更新已安装的库
May 26 #Python
tensorflow 20:搭网络,导出模型,运行模型的实例
May 26 #Python
Python自定义聚合函数merge与transform区别详解
May 26 #Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
May 26 #Python
You might like
Terran历史背景
2020/03/14 星际争霸
PHP中替换换行符的几种方法小结
2012/10/15 PHP
PHP中$_FILES的使用方法及注意事项说明
2014/02/14 PHP
使用新浪微博API的OAuth认证发布微博实例
2015/03/27 PHP
PHP中异常处理的一些方法整理
2015/07/03 PHP
thinkphp微信开之安全模式消息加密解密不成功的解决办法
2015/12/02 PHP
PHP基于swoole多进程操作示例
2019/08/12 PHP
学习YUI.Ext第五日--做拖放Darg&Drop
2007/03/10 Javascript
Document 对象的常用方法
2009/07/31 Javascript
使用jQuery.Validate进行客户端验证(初级篇) 不使用微软验证控件的理由
2010/06/28 Javascript
javascript正则表达式中参数g(全局)的作用
2010/11/11 Javascript
使用UglifyJS合并/压缩JavaScript的方法
2012/03/07 Javascript
javascript 获取HTML DOM父、子、临近节点
2014/06/16 Javascript
NodeJS制作爬虫全过程
2014/12/22 NodeJs
jQuery制作简洁的多级联动Select下拉框
2014/12/23 Javascript
jQuery使用$.get()方法从服务器文件载入数据实例
2015/03/25 Javascript
js实现的简单图片浮动效果完整实例
2016/05/10 Javascript
谈谈Vue.js——vue-resource全攻略
2017/01/16 Javascript
Vue开发过程中遇到的疑惑知识点总结
2017/01/20 Javascript
详解vue+css3做交互特效的方法
2017/11/20 Javascript
微信小程序中使用ECharts 异步加载数据的方法
2018/06/27 Javascript
layui-table表复选框勾选的所有行数据获取的例子
2019/09/13 Javascript
解决Angularjs异步操作后台请求用$q.all排列先后顺序问题
2019/11/29 Javascript
JSONObject与JSONArray使用方法解析
2020/09/28 Javascript
jQuery实现查看图片功能
2020/12/01 jQuery
vue实现简易计算器功能
2021/01/20 Vue.js
Python守护进程用法实例分析
2015/06/04 Python
python实现祝福弹窗效果
2019/04/07 Python
Django模板标签中url使用详解(url跳转到指定页面)
2020/03/19 Python
Hobbs官方网站:英国奢华女性时尚服装
2020/02/22 全球购物
数控技校生自我鉴定
2014/03/02 职场文书
12岁生日演讲稿
2014/05/14 职场文书
雨花台导游词
2015/02/06 职场文书
具结保证书范本
2015/05/11 职场文书
入党积极分子半年考察意见
2015/06/02 职场文书
品牌形象定位,全面分析
2019/07/23 职场文书