tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python类的动态修改的实例方法
Mar 24 Python
Python3中条件控制、循环与函数的简易教程
Nov 21 Python
今天 平安夜 Python 送你一顶圣诞帽 @微信官方
Dec 25 Python
TensorFlow模型保存和提取的方法
Mar 08 Python
Python 中的range(),以及列表切片方法
Jul 02 Python
python爬取微信公众号文章的方法
Feb 26 Python
如何基于线程池提升request模块效率
Apr 18 Python
浅谈Python 函数式编程
Jun 20 Python
Pytest测试框架基本使用方法详解
Nov 25 Python
解决python 执行shell命令无法获取返回值的问题
Dec 05 Python
5个pandas调用函数的方法让数据处理更加灵活自如
Apr 24 Python
Python matplotlib安装以及实现简单曲线的绘制
Apr 26 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
php preg_match_all结合str_replace替换内容中所有img
2008/10/11 PHP
php数组函数序列 之array_count_values() 统计数组中所有值出现的次数函数
2011/10/29 PHP
超级实用的7个PHP代码片段分享
2012/01/05 PHP
PHP flock 文件锁详细介绍
2012/12/29 PHP
php使用socket post数据到其它web服务器的方法
2015/06/02 PHP
PHP简单处理表单输入的特殊字符的方法
2016/02/03 PHP
浅谈PHP中pack、unpack的详细用法
2018/03/12 PHP
Javascript的IE和Firefox兼容性汇编
2006/07/01 Javascript
javaScript 简单验证代码(用户名,密码,邮箱)
2009/09/28 Javascript
ext 列表页面关于多行查询的办法
2010/03/25 Javascript
避免 showModalDialog 弹出新窗体的原因分析
2010/05/31 Javascript
javascript调试说明
2010/06/07 Javascript
juqery 学习之五 文档处理 插入
2011/02/11 Javascript
js滚动条回到顶部的代码
2011/12/06 Javascript
Bootstrap学习笔记之css组件(3)
2016/06/07 Javascript
JQuery控制图片由中心点逐渐放大效果
2016/06/26 Javascript
jquery的checkbox,radio,select等方法小结
2016/08/30 Javascript
写gulp遇到的ES6问题详解
2018/12/03 Javascript
javascript写一个ajax自动拦截并下载数据代码实例
2019/09/07 Javascript
javascript 数组精简技巧小结
2020/02/26 Javascript
python抓取某汽车网数据解析html存入excel示例
2013/12/04 Python
Python中的列表生成式与生成器学习教程
2016/03/13 Python
python 实现自动远程登陆scp文件实例代码
2017/03/13 Python
python字符串的方法与操作大全
2018/01/30 Python
Pycharm连接远程服务器并实现远程调试的实现
2019/08/02 Python
在python中利用pycharm自定义代码块教程(三步搞定)
2020/04/15 Python
如何向接受结构参数的函数传入常数值
2016/02/17 面试题
销售心得体会
2014/01/02 职场文书
大学生职业规划书的范本
2014/02/18 职场文书
股指期货心得体会
2014/09/13 职场文书
教师正风肃纪剖析材料
2014/10/20 职场文书
2015年保安个人工作总结
2015/04/02 职场文书
幼儿园教师师德承诺书
2015/04/28 职场文书
邓小平文选读书笔记
2015/06/29 职场文书
《世界多美呀》教学反思
2016/02/22 职场文书
MySQL中int (10) 和 int (11) 的区别
2022/01/22 MySQL