tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python Django做网页
Nov 04 Python
wxPython定时器wx.Timer简单应用实例
Jun 03 Python
使用Python编写基于DHT协议的BT资源爬虫
Mar 19 Python
Python编程判断一个正整数是否为素数的方法
Apr 14 Python
python dataframe 输出结果整行显示的方法
Jun 14 Python
Windows下Anaconda2安装NLTK教程
Sep 19 Python
python实现在图片上画特定大小角度矩形框
Oct 24 Python
详解Python中pandas的安装操作说明(傻瓜版)
Apr 08 Python
Python上下文管理器类和上下文管理器装饰器contextmanager用法实例分析
Nov 07 Python
使用matplotlib绘制图例标签中带有公式的图
Dec 13 Python
matlab灰度图像调整及imadjust函数的用法详解
Feb 27 Python
Pytorch框架实现mnist手写库识别(与tensorflow对比)
Jul 20 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
PHP源码之 ext/mysql扩展部分
2009/07/17 PHP
php将字符串全部转换成大写或者小写的方法
2015/03/17 PHP
php将print_r处理后的数据还原为原始数组的解决方法
2016/11/02 PHP
基于jquery的超简单上下翻
2010/04/20 Javascript
一个简单的网站访问JS计数器 刷新1次加1次访问
2012/09/20 Javascript
javascript遍历控件实例详细解析
2014/01/10 Javascript
jquery新的绑定事件机制on方法的使用方法
2014/04/15 Javascript
javascript学习笔记(四)function函数部分
2014/09/30 Javascript
jQuery实现HTML5 placeholder效果实例
2014/12/09 Javascript
js鼠标滑过图片震动特效的方法
2015/02/17 Javascript
深入浅析AngularJS中的一次性数据绑定 (bindonce)
2017/05/11 Javascript
Vue-resource拦截器判断token失效跳转的实例
2017/10/27 Javascript
解决VUE中document.body.scrollTop为0的问题
2018/09/15 Javascript
Python 匹配任意字符(包括换行符)的正则表达式写法
2009/10/29 Python
Python用GET方法上传文件
2015/03/10 Python
详解Python中用于计算指数的exp()方法
2015/05/14 Python
Python os模块学习笔记
2015/06/21 Python
python脚本实现xls(xlsx)转成csv
2016/04/10 Python
python使用RNN实现文本分类
2018/05/24 Python
python中实现字符串翻转的方法
2018/07/11 Python
对python3中的RE(正则表达式)-详细总结
2019/07/23 Python
基于python二叉树的构造和打印例子
2019/08/09 Python
python中文分词库jieba使用方法详解
2020/02/11 Python
简单了解python列表和元组的区别
2020/05/14 Python
Python如何将将模块分割成多个文件
2020/08/04 Python
pycharm中选中一个单词替换所有重复单词的实现方法
2020/11/17 Python
python实现学生信息管理系统(精简版)
2020/11/27 Python
HTML5未来发展趋势
2016/02/01 HTML / CSS
全球最大最受欢迎的旅游社区:Tripadvisor
2017/11/03 全球购物
女孩每月服装订阅盒:kidpik
2019/04/17 全球购物
.net C#面试题
2012/08/28 面试题
公司副总经理任命书
2014/06/05 职场文书
2014年助理工程师工作总结
2014/11/14 职场文书
2019年教师节祝福语精选,给老师送上真诚的祝福
2019/09/09 职场文书
Nginx实现会话保持的两种方式
2022/03/18 Servers
基于Python实现西西成语接龙小助手
2022/08/05 Golang