tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Linux下使用python自动修改本机网关代码分享
May 21 Python
Python基于smtplib实现异步发送邮件服务
May 28 Python
python实现计算倒数的方法
Jul 11 Python
python3抓取中文网页的方法
Jul 28 Python
Python实现获取nginx服务器ip及流量统计信息功能示例
May 18 Python
Flask核心机制之上下文源码剖析
Dec 25 Python
Python的log日志功能及设置方法
Jul 11 Python
pycharm的python_stubs问题
Apr 08 Python
Python request操作步骤及代码实例
Apr 13 Python
python 基于wx实现音乐播放
Nov 24 Python
Python异常类型以及处理方法汇总
Jun 05 Python
用PYTHON去计算88键钢琴的琴键频率和音高
Apr 10 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
使用php4加速网络传输
2006/10/09 PHP
PHP实现微信公众平台音乐点播
2014/03/20 PHP
PHP Hash算法:Times33算法代码实例
2015/05/13 PHP
[原创]PHP实现字节数Byte转换为KB、MB、GB、TB的方法
2017/08/31 PHP
Laravel5.5+ 使用API Resources快速输出自定义JSON方法详解
2020/04/06 PHP
Prototype Class对象学习
2009/07/19 Javascript
javascript中的self和this用法小结
2014/02/08 Javascript
javascript 获取iframe里页面中元素值的方法
2014/02/17 Javascript
jQuery操作元素css样式的三种方法
2014/06/04 Javascript
Javascript函数的参数
2015/07/16 Javascript
JS产生随机数的几个用法详解
2016/06/22 Javascript
js防阻塞加载的实现方法
2016/09/09 Javascript
简单学习vue指令directive
2016/11/03 Javascript
vue中用动态组件实现选项卡切换效果
2017/03/25 Javascript
vue2.0实战之基础入门(1)
2017/03/27 Javascript
jQuery遍历节点方法汇总(推荐)
2017/05/13 jQuery
详解解决使用axios发送json后台接收不到的问题
2018/06/27 Javascript
谈谈为什么你的 JavaScript 代码如此冗长
2019/01/30 Javascript
Javascript表单序列化原理及实现代码详解
2020/10/30 Javascript
Python多线程编程(四):使用Lock互斥锁
2015/04/05 Python
python实现一次创建多级目录的方法
2015/05/15 Python
python3中int(整型)的使用教程
2017/03/23 Python
python xlsxwriter库生成图表的应用示例
2018/03/16 Python
使用Python测试Ping主机IP和某端口是否开放的实例
2019/12/17 Python
Jupyter notebook如何实现指定浏览器打开
2020/05/13 Python
Python爬虫如何破解JS加密的Cookie
2020/11/19 Python
世界闻名的衬衫制造商:Savile Row Company
2018/07/30 全球购物
用C#语言写出在本地创建一个UDP接收端口的具体过程
2016/02/22 面试题
高中毕业自我鉴定
2013/12/16 职场文书
优秀班集体获奖感言
2014/02/03 职场文书
《蝙蝠和雷达》教学反思
2014/04/23 职场文书
2015年部门工作总结范文
2015/03/31 职场文书
2015年党务公开工作总结
2015/05/19 职场文书
追讨欠款律师函
2015/06/24 职场文书
《大禹治水》教学反思
2016/02/22 职场文书
​(迎国庆)作文之我爱我的祖国
2019/09/19 职场文书