tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作Word批量生成文章的方法
Jul 28 Python
Python中用psycopg2模块操作PostgreSQL方法
Nov 28 Python
Django中使用celery完成异步任务的示例代码
Jan 23 Python
利用Python如何将数据写到CSV文件中
Jun 05 Python
Python DataFrame 设置输出不显示index(索引)值的方法
Jun 07 Python
Python单元测试unittest的具体使用示例
Dec 17 Python
Python字典推导式将cookie字符串转化为字典解析
Aug 10 Python
Django 创建后台,配置sqlite3教程
Nov 18 Python
在python里使用await关键字来等另外一个协程的实例
May 04 Python
解决Pycharm双击图标启动不了的问题(JetBrains全家桶通用)
Aug 07 Python
Django跨域请求原理及实现代码
Nov 14 Python
利用python实时刷新基金估值(摸鱼小工具)
Sep 15 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
php中mysql连接方式PDO使用详解
2015/02/25 PHP
phpmyadmin下载、安装、配置教程
2017/05/16 PHP
javascript 面向对象编程  function是方法(函数)
2009/09/17 Javascript
Javascript实现颜色rgb与16进制转换的方法
2015/04/18 Javascript
延时加载JavaScript代码提高速度
2015/12/27 Javascript
5种JavaScript脚本加载的方式
2017/01/16 Javascript
React Native时间转换格式工具类分享
2017/10/24 Javascript
vue-cli + sass 的正确打开方式图文详解
2017/10/27 Javascript
AnglarJs中的上拉加载实现代码
2018/02/08 Javascript
详解Vue 多级组件透传新方法provide/inject
2018/05/09 Javascript
Node.js+ELK日志规范的实现
2019/05/23 Javascript
微信小程序实现列表左右滑动
2020/11/19 Javascript
[00:31]2016完美“圣”典风云人物:国士无双宣传片
2016/12/04 DOTA
[01:17:55]VGJ.T vs Mineski 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/20 DOTA
Python 列表(List)操作方法详解
2014/03/11 Python
浅谈python中的实例方法、类方法和静态方法
2017/02/17 Python
关于Python如何避免循环导入问题详解
2017/09/14 Python
python3使用requests模块爬取页面内容的实战演练
2017/09/25 Python
python+selenium实现简历自动刷新的示例代码
2019/05/20 Python
python实现合并多个list及合并多个django QuerySet的方法示例
2019/06/11 Python
Django如何将URL映射到视图
2019/07/29 Python
Python 单例设计模式用法实例分析
2019/09/23 Python
Python 迭代,for...in遍历,迭代原理与应用示例
2019/10/12 Python
Python 转换RGB颜色值的示例代码
2019/10/13 Python
python生成器用法实例详解
2019/11/22 Python
Matplotlib中%matplotlib inline如何使用
2020/07/28 Python
Python pickle模块常用方法代码实例
2020/10/10 Python
HTML5有哪些新特征
2015/12/01 HTML / CSS
阿迪达斯比利时官方商城:adidas比利时
2016/10/10 全球购物
体育纪念品、亲笔签名的体育收藏品:Steiner Sports
2020/07/31 全球购物
医学检验专业大学生求职信
2013/11/18 职场文书
医学专业大学生职业生涯规划书
2014/10/25 职场文书
《认识钟表》教学反思
2016/02/16 职场文书
html5调用摄像头实例代码
2021/06/28 HTML / CSS
spring boot项目application.properties文件存放及使用介绍
2021/06/30 Java/Android
Android Canvas绘制文字横纵向对齐
2022/06/05 Java/Android