tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxPython窗口中文乱码解决方法
Oct 11 Python
Python实现的简单文件传输服务器和客户端
Apr 08 Python
Python实现二分查找算法实例
May 26 Python
Python批量更改文件名的实现方法
Oct 29 Python
numpy.linspace函数具体使用详解
May 27 Python
Django框架封装外部函数示例
May 28 Python
Python Django 封装分页成通用的模块详解
Aug 21 Python
python sklearn常用分类算法模型的调用
Oct 16 Python
numpy.array 操作使用简单总结
Nov 08 Python
wxPython实现分隔窗口
Nov 19 Python
Python调用钉钉自定义机器人的实现
Jan 03 Python
Pycharm安装python库的方法
Nov 24 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
ThinkPHP中的关联模型注意点
2014/06/16 PHP
smarty实现多级分类的方法
2014/12/05 PHP
Zend Framework开发入门经典教程
2016/03/23 PHP
Prototype Date对象 学习
2009/07/12 Javascript
flexigrid 参数说明
2010/11/23 Javascript
cookie在javascript中的使用技巧以及隐私在服务器端的设置
2012/12/03 Javascript
查找页面中所有类为test的结点的方法
2014/03/28 Javascript
JavaScript中具名函数的多种调用方式总结
2014/11/08 Javascript
js全选实现和判断是否有复选框选中的方法
2015/02/17 Javascript
javascript制作2048游戏
2015/03/30 Javascript
JS实现可直接显示网页代码运行效果的HTML代码预览功能实例
2015/08/06 Javascript
JS实现合并两个数组并去除重复项只留一个的方法
2015/12/17 Javascript
基于javascript实现随机颜色变化效果
2016/01/14 Javascript
jQuery基于扩展简单实现倒计时功能的方法
2016/05/14 Javascript
BootStrap Select清除选中的状态恢复默认状态
2017/06/20 Javascript
angular+ionic返回上一页并刷新页面
2017/08/08 Javascript
微信小程序模板和模块化用法实例分析
2017/11/28 Javascript
解决vue-cli + webpack 新建项目出错的问题
2018/03/20 Javascript
vue动画之点击按钮往上渐渐显示出来的实例
2018/09/29 Javascript
js常见遍历操作小结
2019/06/06 Javascript
vue的$http的get请求要加上params操作
2020/11/12 Javascript
[58:58]2018DOTA2亚洲邀请赛 4.4 淘汰赛 TNC vs VG 第二场
2018/04/05 DOTA
Python使用multiprocessing创建进程的方法
2015/06/04 Python
Python中的左斜杠、右斜杠(正斜杠和反斜杠)
2016/08/30 Python
Pycharm远程调试openstack的方法
2017/11/21 Python
python3 发送任意文件邮件的实例
2018/01/23 Python
python保存文件方法小结
2018/07/27 Python
浅谈pyqt5中信号与槽的认识
2019/02/17 Python
Python弹出输入框并获取输入值的实例
2019/06/18 Python
用python求一重积分和二重积分的例子
2019/12/06 Python
python实现同一局域网下传输图片
2020/03/20 Python
浅谈移动端网页图片预加载方案
2018/11/05 HTML / CSS
丹尼尔惠灵顿手表天猫官方旗舰店:Daniel Wellington
2017/08/25 全球购物
新闻编辑专业毕业自荐书范文
2014/02/05 职场文书
SQL 聚合、分组和排序
2021/11/11 MySQL
js面向对象编程OOP及函数式编程FP区别
2022/07/07 Javascript