tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python3 编写简单信用卡管理程序
Dec 21 Python
Django的分页器实例(paginator)
Dec 01 Python
python+opencv实现动态物体追踪
Jan 09 Python
基于python 二维数组及画图的实例详解
Apr 03 Python
基于python 爬虫爬到含空格的url的处理方法
May 11 Python
Python图像的增强处理操作示例【基于ImageEnhance类】
Jan 03 Python
python求numpy中array按列非零元素的平均值案例
Jun 08 Python
python 使用建议与技巧分享(四)
Aug 18 Python
python中的yield from语法快速学习
Nov 06 Python
Pytorch可视化的几种实现方法
Jun 10 Python
python脚本框架webpy的url映射详解
Nov 20 Python
分享3个非常实用的 Python 模块
Mar 03 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
一步一步学习PHP(5) 类和对象
2010/02/16 PHP
PHP四舍五入、取整、round函数使用示例
2015/02/06 PHP
javascript KeyDown、KeyPress和KeyUp事件的区别与联系
2009/12/03 Javascript
基于JQUERY的多级联动代码
2012/01/24 Javascript
可自己添加html的伪弹出框实现代码
2013/09/08 Javascript
Firefox和IE兼容性问题及解决方法总结
2013/10/08 Javascript
nodejs npm package.json中文文档
2014/09/04 NodeJs
node.js中使用socket.io制作命名空间
2014/12/15 Javascript
js获取元素外链样式的方法
2015/01/27 Javascript
JS制作简单的三级联动
2015/03/18 Javascript
jQuery插件HighCharts实现的2D面积图效果示例【附demo源码下载】
2017/03/15 Javascript
Javascript面试经典套路reduce函数查重
2017/03/23 Javascript
详解vue.js2.0父组件点击触发子组件方法
2017/05/10 Javascript
js实现文字列表无缝滚动效果
2017/06/23 Javascript
深入理解Vue transition源码分析
2017/07/30 Javascript
Vue二次封装axios为插件使用详解
2018/05/21 Javascript
JS中的算法与数据结构之集合(Set)实例详解
2019/08/20 Javascript
[54:25]Ti4 循环赛第三日LGD vs MOUZ
2014/07/12 DOTA
Python中文件遍历的两种方法
2014/06/16 Python
Python简直是万能的,这5大主要用途你一定要知道!(推荐)
2019/04/03 Python
python 实现返回一个列表中出现次数最多的元素方法
2019/06/11 Python
django项目环境搭建及在虚拟机本地创建django项目的教程
2019/08/02 Python
Python实现RGB与HSI颜色空间的互换方式
2019/11/27 Python
利用pandas向一个csv文件追加写入数据的实现示例
2020/04/23 Python
python制作微博图片爬取工具
2021/01/16 Python
Probikekit日本:自行车套件,跑步和铁人三项装备
2017/04/03 全球购物
《蜗牛》教学反思
2014/02/18 职场文书
管理标语大全
2014/06/24 职场文书
2014年党员整改措施范文
2014/09/21 职场文书
四年级学生期末评语
2014/12/26 职场文书
个人工作年终总结
2015/03/09 职场文书
企业战略合作意向书
2015/05/08 职场文书
导游词之扬州大明寺
2019/10/09 职场文书
Python爬虫爬取全球疫情数据并存储到mysql数据库的步骤
2021/03/29 Python
python实现简单倒计时功能
2021/04/21 Python
利用Redis实现点赞功能的示例代码
2022/06/28 Redis