tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现把utf-8格式的文件转换成gbk格式的文件
Jan 22 Python
Python画图学习入门教程
Jul 01 Python
Python 装饰器实现DRY(不重复代码)原则
Mar 05 Python
Django 中使用流响应处理视频的方法
Jul 20 Python
Scrapy使用的基本流程与实例讲解
Oct 21 Python
Python实现Event回调机制的方法
Feb 13 Python
PySide和PyQt加载ui文件的两种方法
Feb 27 Python
django 邮件发送模块smtp使用详解
Jul 22 Python
深入了解python中元类的相关知识
Aug 29 Python
Python中zip()函数的简单用法举例
Sep 02 Python
Python3如何使用tabulate打印数据
Sep 25 Python
详解Selenium 元素定位和WebDriver常用方法
Dec 04 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
PHP无限分类(树形类)的深入分析
2013/06/02 PHP
滚动经典最新话题[prototype框架]下编写
2006/10/03 Javascript
Javascript的构造函数和constructor属性
2010/01/09 Javascript
JavaScript页面刷新与弹出窗口问题的解决方法
2010/03/02 Javascript
js下判断 iframe 是否加载完成的完美方法
2010/10/26 Javascript
javascript 基础篇2 数据类型,语句,函数
2012/03/14 Javascript
解决js数据包含加号+通过ajax传到后台时出现连接错误
2013/08/01 Javascript
jquery将一个表单序列化为一个对象的方法
2013/12/02 Javascript
javascript动态向网页中添加表格实现代码
2014/02/19 Javascript
Javascript实现颜色rgb与16进制转换的方法
2015/04/18 Javascript
JavaScript中操作字符串之localeCompare()方法的使用
2015/06/06 Javascript
Bootstrap中的fileinput 多图片上传及编辑功能
2016/09/05 Javascript
webpack处理 css\less\sass 样式的方法
2017/08/21 Javascript
Angular2里获取(input file)上传文件的内容的方法
2017/09/05 Javascript
elementUI 设置input的只读或禁用的方法
2018/10/30 Javascript
在Vant的基础上封装下拉日期控件的代码示例
2018/12/05 Javascript
JS typeof fn === 'function' && fn()详解
2020/08/22 Javascript
微信小程序接入vant Weapp组件的详细步骤
2020/10/28 Javascript
Vue+scss白天和夜间模式切换功能的实现方法
2021/01/05 Vue.js
python 读写txt文件 json文件的实现方法
2016/10/22 Python
15行Python代码带你轻松理解令牌桶算法
2018/03/21 Python
django将图片上传数据库后在前端显式的方法
2018/05/25 Python
解析Python的缩进规则的使用
2019/01/16 Python
CSS3实现任意图片lowpoly动画效果实例
2017/05/11 HTML / CSS
英国高档时尚男装购物网站:MR PORTER
2016/08/09 全球购物
全球最大的在线旅游公司:Expedia
2017/11/16 全球购物
巴西儿童时尚购物网站:Dinda
2019/08/14 全球购物
英国旅行箱包和行李箱购物网站:Travel Luggage & Cabin Bags
2019/08/26 全球购物
餐厅采购员岗位职责
2014/03/06 职场文书
小学生演讲稿大全
2014/04/25 职场文书
青年文明号口号
2014/06/17 职场文书
员工试用期自我鉴定范文
2014/09/15 职场文书
2015年干部教育培训工作总结
2015/05/15 职场文书
法制教育观后感
2015/06/17 职场文书
《从现在开始》教学反思
2016/02/16 职场文书
Python中X[:,0]和X[:,1]的用法
2021/05/10 Python