tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python实现一个简单的项目监控
Mar 31 Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 Python
Python复制Word内容并使用格式设字体与大小实例代码
Jan 22 Python
PyCharm设置护眼背景色的方法
Oct 29 Python
Python read函数按字节(字符)读取文件的实现
Jul 03 Python
pyqt5 QScrollArea设置在自定义侧(任何位置)
Sep 25 Python
python实现连连看游戏
Feb 14 Python
基于pygame实现童年掌机打砖块游戏
Feb 25 Python
Pytorch框架实现mnist手写库识别(与tensorflow对比)
Jul 20 Python
python 读取串口数据的示例
Nov 09 Python
Python实现列表索引批量删除的5种方法
Nov 16 Python
详解Python常用的魔法方法
Jun 03 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
php中的strpos使用示例
2014/02/27 PHP
thinkphp3.2实现在线留言提交验证码功能
2017/07/19 PHP
Laravel中encrypt和decrypt的实现方法
2017/09/24 PHP
Laravel 5.5 异常处理 & 错误日志的解决
2019/10/17 PHP
javascript生成/解析dom的CDATA类型的字段的代码
2007/04/22 Javascript
深入理解JavaScript系列(16) 闭包(Closures)
2012/04/12 Javascript
JavaScript的常见兼容问题及相关解决方法(chrome/IE/firefox)
2013/12/31 Javascript
一些老手都不一定知道的JavaScript技巧
2014/05/06 Javascript
使用javascript实现雪花飘落的效果
2015/01/13 Javascript
js实现鼠标触发图片抖动效果的方法
2015/02/27 Javascript
JavaScript Array对象详解
2016/03/01 Javascript
js判断空对象的实例(超简单)
2016/07/26 Javascript
jQuery选择器总结之常用元素查找方法
2016/08/04 Javascript
Javascript中字符串相关常用的使用方法总结
2017/03/13 Javascript
Nodejs之http的表单提交
2017/07/07 NodeJs
js实现手机web图片左右滑动效果
2017/12/29 Javascript
vue forEach循环数组拿到自己想要的数据方法
2018/09/21 Javascript
js核心基础之构造函数constructor用法实例分析
2019/05/11 Javascript
Vue基于localStorage存储信息代码实例
2020/11/16 Javascript
[01:35:13]DOTA2-DPC中国联赛 正赛 DLG vs PHOENIX BO3 第一场 1月18日
2021/03/11 DOTA
Python对CSV、Excel、txt、dat文件的处理
2018/09/18 Python
Python3之手动创建迭代器的实例代码
2019/05/22 Python
Python实现CAN报文转换工具教程
2020/05/05 Python
selenium3.0+python之环境搭建的方法步骤
2021/02/01 Python
浅谈CSS3中的变形功能-transform功能
2017/12/27 HTML / CSS
巴西食品补充剂在线零售商:Músculos na Web
2017/08/07 全球购物
Foot Locker德国官方网站:美国运动服和鞋类零售商
2018/11/01 全球购物
加拿大在线隐形眼镜和眼镜店:VisionPros
2019/10/06 全球购物
党员作风建设自查报告
2014/10/23 职场文书
会计工作检讨书
2015/02/19 职场文书
合同审查法律意见书
2015/06/04 职场文书
美德少年事迹材料(2016推荐版)
2016/02/25 职场文书
Django中session进行权限管理的使用
2021/07/09 Python
MySQL空间数据存储及函数
2021/09/25 MySQL
springcloud整合seata
2022/05/20 Java/Android
css之clearfix的用法深入理解(必看篇)
2023/05/21 HTML / CSS