tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
tensorflow: variable的值与variable.read_value()的值区别详解
Jul 30 Python
python3实现多线程聊天室
Dec 12 Python
在Python 中同一个类两个函数间变量的调用方法
Jan 31 Python
Python中Numpy ndarray的使用详解
May 24 Python
PyQt5实现从主窗口打开子窗口的方法
Jun 19 Python
Django 实现前端图片压缩功能的方法
Aug 07 Python
python打印文件的前几行或最后几行教程
Feb 13 Python
使用sklearn的cross_val_score进行交叉验证实例
Feb 28 Python
如何在django中运行scrapy框架
Apr 22 Python
基于Keras 循环训练模型跑数据时内存泄漏的解决方式
Jun 11 Python
用python实现一个简单计算器(完整DEMO)
Oct 14 Python
python 多线程爬取壁纸网站的示例
Feb 20 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
PHP函数篇详解十进制、二进制、八进制和十六进制转换函数说明
2011/12/05 PHP
Zend Framework入门教程之Zend_View组件用法示例
2016/12/09 PHP
Aster vs KG BO3 第三场2.18
2021/03/10 DOTA
Prototype String对象 学习
2009/07/19 Javascript
深入理解JavaScript中的传值与传引用
2013/12/09 Javascript
优化Node.js Web应用运行速度的10个技巧
2014/09/03 Javascript
node.js+express制作网页计算器
2016/01/17 Javascript
详解javascript传统方法实现异步校验
2016/01/22 Javascript
AngularJS 如何在控制台进行错误调试
2016/06/07 Javascript
JavaScript代码里的判断小结
2016/08/22 Javascript
js注入 黑客之路必备!
2016/09/14 Javascript
概述VUE2.0不可忽视的很多变化
2016/09/25 Javascript
JS实现搜索框文字可删除功能
2016/12/28 Javascript
ES6中module模块化开发实例浅析
2017/04/06 Javascript
为什么说JavaScript预解释是一种毫无节操的机制详析
2018/11/18 Javascript
微信小程序自定义导航教程(兼容各种手机)
2018/12/12 Javascript
Vuex mutitons和actions初使用详解
2019/03/04 Javascript
解决Layui 表格自适应高度的问题
2019/11/15 Javascript
使用element-ui +Vue 解决 table 里包含表单验证的问题
2020/07/17 Javascript
[01:10:02]IG vs Winstrike 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
python中使用urllib2获取http请求状态码的代码例子
2014/07/07 Python
Python 多进程和数据传递的理解
2017/10/09 Python
python实现内存监控系统
2021/03/07 Python
python交互界面的退出方法
2019/02/16 Python
Python3 中作为一等对象的函数解析
2019/12/11 Python
三个python爬虫项目实例代码
2019/12/28 Python
Pycharm创建python文件自动添加日期作者等信息(步骤详解)
2021/02/03 Python
澳大利亚床上用品、浴巾和家居用品购物网站:Bambury
2020/04/16 全球购物
大学毕业生最详细的自我评价分享
2013/11/18 职场文书
简单租房协议书
2014/04/09 职场文书
我爱我家教学反思
2014/05/01 职场文书
年终晚会活动方案
2014/08/21 职场文书
表扬稿范文
2015/01/17 职场文书
2016公司年会通知范文
2015/04/25 职场文书
客户答谢会致辞
2015/07/30 职场文书
python基础之匿名函数详解
2021/04/21 Python