tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python del()函数用法
Mar 24 Python
讲解Python中fileno()方法的使用
May 24 Python
python中将正则过滤的内容输出写入到文件中的实例
Oct 21 Python
Python面向对象程序设计之私有属性及私有方法示例
Apr 08 Python
Django Admin中增加导出CSV功能过程解析
Sep 04 Python
python数据预处理方式 :数据降维
Feb 24 Python
python如何停止递归
Sep 09 Python
Python爬取豆瓣数据实现过程解析
Oct 27 Python
使paramiko库执行命令时在给定的时间强制退出功能的实现
Mar 03 Python
Python中for后接else的语法使用
May 18 Python
Python下opencv库的安装过程及问题汇总
Jun 11 Python
利用For循环遍历Python字典的三种方法实例
Mar 25 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
保存到桌面、设为桌面且带图标的PHP代码
2013/11/19 PHP
PHP session文件独占锁引起阻塞问题解决方法
2015/05/12 PHP
php根据年月获取当月天数及日期数组的方法
2016/11/30 PHP
基于swoole实现多人聊天室
2018/06/14 PHP
phpstudy隐藏index.php的方法
2020/09/21 PHP
TNC vs IO BO3 第一场2.13
2021/03/10 DOTA
ExtJS 2.0实用简明教程之应用ExtJS
2009/04/29 Javascript
JavaScript Memoization 让函数也有记忆功能
2011/10/27 Javascript
jquery 缓存问题的几个解决方法
2013/11/11 Javascript
JS的document.all函数使用示例
2013/12/30 Javascript
基于Arcgis for javascript实现百度地图ABCD marker的效果
2015/09/12 Javascript
js格式化时间的简单实例
2016/11/27 Javascript
微信小程序 省市区选择器实例详解(附源码下载)
2017/01/05 Javascript
常用的javascript设计模式
2017/01/11 Javascript
layui点击按钮添加可编辑的一行方法
2018/08/15 Javascript
浅谈一种让小程序支持JSX语法的新思路
2019/06/16 Javascript
vue-router两种模式区别及使用注意事项详解
2019/08/01 Javascript
vue-router结合vuex实现用户权限控制功能
2019/11/14 Javascript
jQuery实现放大镜案例
2020/10/19 jQuery
vue实现两个组件之间数据共享和修改操作
2020/11/12 Javascript
详解Python开发中如何使用Hook技巧
2017/11/01 Python
python计算两个数的百分比方法
2018/06/29 Python
通过实例简单了解Python中yield的作用
2019/12/11 Python
Python字符串的15个基本操作(小结)
2021/02/03 Python
解决pytorch下出现multi-target not supported at的一种可能原因
2021/02/06 Python
HTML5拖放效果的实现代码
2016/11/17 HTML / CSS
巴西葡萄酒商店:Divvino
2020/02/22 全球购物
高级护理专业大学生求职信
2013/10/24 职场文书
医德医风自我评价
2014/09/19 职场文书
2014副镇长民主生活会个人对照检查材料思想汇报
2014/09/30 职场文书
自查自纠整改报告
2014/11/06 职场文书
银行资信证明
2015/06/17 职场文书
Python中OpenCV实现查找轮廓的实例
2021/06/08 Python
用Python编写简单的gRPC服务的详细过程
2021/07/04 Python
「月刊Action」2022年5月号封面公开
2022/03/21 日漫
怎么禁用Win11输入法 最新Win11输入法关闭教程
2022/08/05 数码科技