tensorflow如何继续训练之前保存的模型实例


Posted in Python onJanuary 21, 2020

一:需重定义神经网络继续训练的方法

1.训练代码

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32) 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
 
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
 
init=tf.global_variables_initializer() 
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step) #保存
  print("当前进行:",step)

第一次训练截图:

tensorflow如何继续训练之前保存的模型实例

2.恢复上一次的训练

import numpy as np
 
import tensorflow as tf
 
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
 
print(sess.run("w:0"),sess.run("b:0"))
 
 
 
graph=tf.get_default_graph() 
weight=graph.get_tensor_by_name("w:0") 
biases=graph.get_tensor_by_name("b:0")
 
 
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
 
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(train)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

使用上次保存下的数据进行继续训练和保存:

tensorflow如何继续训练之前保存的模型实例

#最后要提一下的是:

checkpoint文件

meta保存了TensorFlow计算图的结构信息

datat保存每个变量的取值

index保存了 表

加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的

这个方法需要重新定义神经网络

二:不需要重新定义神经网络的方法:

在上面训练的代码中加入:tf.add_to_collection("name",参数)

import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
 
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
 
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
 
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
 
for step in range(10):
  sess.run(train)
  saver.save(sess,"./save_mode",global_step=step)
  print("当前进行:",step)

在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了

import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
 
y=tf.get_collection("new_way")[0]
 
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
  sess.run(y)
  saver.save(sess,r"./save_new_mode",global_step=step)
  print("当前进行:",step," ",sess.run(weight),sess.run(biases))

总的来说,下面这种方法好像是要便利一些

以上这篇tensorflow如何继续训练之前保存的模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python编程之多态用法实例详解
May 19 Python
Python实现的人工神经网络算法示例【基于反向传播算法】
Nov 11 Python
python版学生管理系统
Jan 10 Python
python实现对文件中图片生成带标签的txt文件方法
Apr 27 Python
python中的decorator的作用详解
Jul 26 Python
django框架实现模板中获取request 的各种信息示例
Jul 01 Python
python爬虫 线程池创建并获取文件代码实例
Sep 28 Python
利用pyshp包给shapefile文件添加字段的实例
Dec 06 Python
keras获得某一层或者某层权重的输出实例
Jan 24 Python
解决使用python print打印函数返回值多一个None的问题
Apr 09 Python
如何从csv文件构建Tensorflow的数据集
Sep 21 Python
python元组打包和解包过程详解
Aug 02 Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
TensorFlow——Checkpoint为模型添加检查点的实例
Jan 21 #Python
tensorflow estimator 使用hook实现finetune方式
Jan 21 #Python
Python实现FLV视频拼接功能
Jan 21 #Python
TFRecord格式存储数据与队列读取实例
Jan 21 #Python
TensorFlow dataset.shuffle、batch、repeat的使用详解
Jan 21 #Python
使用 tf.nn.dynamic_rnn 展开时间维度方式
Jan 21 #Python
You might like
微盾PHP脚本加密专家php解密算法
2020/09/13 PHP
php地址引用(php地址引用的效率问题)
2012/03/23 PHP
深入PHP empty(),isset(),is_null()的实例测试详解
2013/06/06 PHP
PHP获取一个字符串中间一部分字符的方法
2014/08/19 PHP
基于linnux+phantomjs实现生成图片格式的网页快照
2015/04/15 PHP
PHP仿tp实现mvc框架基本设计思路与实现方法分析
2018/05/23 PHP
微信公众号开发之获取位置信息php代码
2018/06/13 PHP
文本加密解密
2006/06/23 Javascript
js 分页全选或反选标识实现代码
2011/08/09 Javascript
jQuery 过滤not()与filter()实例代码
2012/05/10 Javascript
JavaScript中的eval()函数详解
2013/08/22 Javascript
node.js中的fs.truncate方法使用说明
2014/12/15 Javascript
echarts实现地图定时切换散点与多图表级联联动详解
2018/08/07 Javascript
webpack 开发和生产并行设置的方法
2018/11/08 Javascript
JavaScript 中 JSON.parse 函数 和 JSON.stringify 函数
2018/12/05 Javascript
用原生 JS 实现 innerHTML 功能实例详解
2019/04/03 Javascript
vue插槽slot的理解和使用方法
2019/04/03 Javascript
详解jenkins自动化部署vue
2019/05/14 Javascript
微信小程序实现打卡签到页面
2020/09/21 Javascript
win与linux系统中python requests 安装
2016/12/04 Python
python如何使用unittest测试接口
2018/04/04 Python
从运行效率与开发效率比较Python和C++
2018/12/14 Python
Python logging模块handlers用法详解
2020/08/14 Python
详解Python中的路径问题
2020/09/02 Python
详解win10下pytorch-gpu安装以及CUDA详细安装过程
2021/01/28 Python
购买英国原创艺术:Art Gallery
2018/08/25 全球购物
优秀毕业生求职推荐信范文
2013/11/21 职场文书
学校学雷锋活动总结
2014/06/26 职场文书
市场策划求职信
2014/08/07 职场文书
2014年调度员工作总结
2014/11/19 职场文书
2014年煤矿工人工作总结
2014/12/08 职场文书
一般纳税人申请报告
2015/05/18 职场文书
招商银行收入证明
2015/06/17 职场文书
MySQL数据库中varchar类型的数字比较大小的方法
2021/11/17 MySQL
Java 定时任务技术趋势简介
2022/05/04 Java/Android
JS实现九宫格拼图游戏
2022/06/28 Javascript