tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python文件比较示例分享
Jan 10 Python
让python 3支持mysqldb的解决方法
Feb 14 Python
Python实现PS图像调整颜色梯度效果示例
Jan 25 Python
Python使用pip安装pySerial串口通讯模块
Apr 20 Python
解决Pycharm下面出现No R interpreter defined的问题
Oct 29 Python
如何使用python操作vmware
Jul 27 Python
python with (as)语句实例详解
Feb 04 Python
python实现梯度下降和逻辑回归
Mar 24 Python
判断Threading.start新线程是否执行完毕的实例
May 02 Python
python继承threading.Thread实现有返回值的子类实例
May 02 Python
浅析Python面向对象编程
Jul 10 Python
pytorch训练神经网络爆内存的解决方案
May 22 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
索尼ICF-SW100收音机评测
2021/03/02 无线电
php empty,isset,is_null判断比较(差异与异同)
2010/10/19 PHP
一个经典的PHP验证码类分享
2014/11/18 PHP
提交表单后 PHP获取提交内容的实现方法
2016/05/25 PHP
js获取class的所有元素
2013/03/28 Javascript
深入领悟JavaScript中的面向对象
2013/11/18 Javascript
js 判断图片是否加载完以及实现图片的预下载
2014/08/14 Javascript
JavaScript常用的弹出广告及背投广告实现方法
2015/02/06 Javascript
javascript动态添加checkbox复选框的方法
2015/12/23 Javascript
Angular的Bootstrap(引导)和Compiler(编译)机制
2016/06/20 Javascript
js时间戳和c#时间戳互转方法(推荐)
2017/02/15 Javascript
Three.js获取鼠标点击的三维坐标示例代码
2017/03/24 Javascript
JS实现按钮颜色切换效果
2020/09/05 Javascript
jQuery图片加载失败替换默认图片方法汇总
2017/11/29 jQuery
详解vue-cli官方脚手架配置
2018/07/20 Javascript
详解vue.js移动端配置flexible.js及注意事项
2019/04/10 Javascript
layui 数据表格 点击分页按钮 监听事件的实例
2019/09/02 Javascript
使用Vue调取接口,并渲染数据的示例代码
2019/10/28 Javascript
node.js中fs文件系统模块的使用方法实例详解
2020/02/13 Javascript
微信小程序间使用navigator跳转传值问题实例分析
2020/03/27 Javascript
vue实现简单加法计算器
2020/10/22 Javascript
vue+element table表格实现动态列筛选的示例代码
2021/01/14 Vue.js
Tensorflow加载预训练模型和保存模型的实例
2018/07/27 Python
python读取txt文件,去掉空格计算每行长度的方法
2018/12/20 Python
Python类和对象的定义与实际应用案例分析
2018/12/27 Python
selenium+python截图不成功的解决方法
2019/01/30 Python
python操作docx写入内容,并控制文本的字体颜色
2020/02/13 Python
解决pycharm导入numpy包的和使用时报错:RuntimeError: The current Numpy installation (‘D:\\python3.6\\lib\\site-packa的问题
2020/12/08 Python
美国最大的珠宝商之一:Littman Jewelers
2016/11/13 全球购物
同步和异步有何异同,在什么情况下分别使用他们?举例说明
2014/02/27 面试题
企业口号大全
2014/06/12 职场文书
铣床操作工岗位职责
2014/06/13 职场文书
2014年学生会工作总结
2014/11/07 职场文书
完美解决golang go get私有仓库的问题
2021/05/05 Golang
php中配置文件保存修改操作 如config.php文件的读取修改等操作
2021/05/12 PHP
js作用域及作用域链工作引擎
2022/07/07 Javascript