tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Tkinter简单布局实例教程
Sep 03 Python
举例讲解Python面相对象编程中对象的属性与类的方法
Jan 19 Python
浅谈Python中函数的参数传递
Jun 21 Python
python爬虫面试宝典(常见问题)
Mar 02 Python
Python cookbook(数据结构与算法)将名称映射到序列元素中的方法
Mar 22 Python
python中plot实现即时数据动态显示方法
Jun 22 Python
Django进阶之CSRF的解决
Aug 01 Python
python启动应用程序和终止应用程序的方法
Jun 28 Python
python模块导入的方法
Oct 24 Python
Python Sphinx使用实例及问题解决
Jan 17 Python
python判断正负数方式
Jun 03 Python
Python中X[:,0]和X[:,1]的用法
May 10 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
留言板翻页的实现详解
2006/10/09 PHP
PHP实现抓取Google IP并自动修改hosts文件
2015/02/12 PHP
不用AJAX和IFRAME,说说真正意义上的ASP+JS无刷新技术
2008/09/25 Javascript
Zero Clipboard js+swf实现的复制功能使用方法
2010/03/07 Javascript
深入理解JavaScript 闭包究竟是什么
2013/04/12 Javascript
jquery实现两个图片渐变切换效果的方法
2015/06/25 Javascript
最常见的左侧分类菜单栏jQuery实现代码
2016/11/28 Javascript
微信页面倒计时代码(解决safari不兼容date的问题)
2016/12/13 Javascript
JS实现课堂随机点名和顺序点名
2017/03/09 Javascript
JS实现图片预加载之无序预加载功能代码
2017/05/12 Javascript
Jquery中attr与prop的区别详解
2017/05/27 jQuery
php中and 和 &&出坑指南
2018/07/13 Javascript
Vue动态生成el-checkbox点击无法赋值的解决方法
2019/02/21 Javascript
JavaScript ECMA-262-3 深入解析(二):变量对象实例详解
2020/04/25 Javascript
vue相关配置文件详解及多环境配置详细步骤
2020/05/19 Javascript
vue 函数调用加括号与不加括号的区别
2020/10/29 Javascript
python之消除前缀重命名的方法
2018/10/21 Python
Python中安装easy_install的方法
2018/11/18 Python
python使用thrift教程的方法示例
2019/03/21 Python
用Python配平化学方程式的方法
2019/07/20 Python
python pygame实现球球大作战
2019/11/25 Python
python操作gitlab API过程解析
2019/12/27 Python
Python 实现 T00ls 自动签到脚本代码(邮件+钉钉通知)
2020/07/06 Python
澳大利亚最大的女装零售商:Millers
2017/09/10 全球购物
香港士多网上超级市场:Ztore
2021/01/09 全球购物
小学教师师德反思
2014/02/03 职场文书
社区网格化管理实施方案
2014/03/21 职场文书
安全承诺书范文
2014/03/26 职场文书
法制宣传日活动总结
2014/04/29 职场文书
会员卡清退活动总结
2014/08/27 职场文书
2014审计局领导班子民主生活会对照检查材料思想汇报
2014/09/20 职场文书
现实表现材料范文
2014/12/23 职场文书
单位接收函范文
2015/01/30 职场文书
自荐信格式模板
2015/03/27 职场文书
债务纠纷起诉书
2015/05/20 职场文书
Python编程中Python与GIL互斥锁关系作用分析
2021/09/15 Python