tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将MongoDB里的ObjectId转换为时间戳的方法
Mar 13 Python
python生成随机mac地址的方法
Mar 16 Python
Python对CSV、Excel、txt、dat文件的处理
Sep 18 Python
Windows下安装Scrapy
Oct 17 Python
python实现扫描局域网指定网段ip的方法
Apr 16 Python
python 列表中[ ]中冒号‘:’的作用
Apr 30 Python
pyqt5 实现工具栏文字图片同时显示
Jun 13 Python
Python创建一个元素都为0的列表实例
Nov 28 Python
基于python实现matlab filter函数过程详解
Jun 08 Python
Python代码注释规范代码实例解析
Aug 14 Python
浅析python函数式编程
Sep 26 Python
python进行二次方程式计算的实例讲解
Dec 06 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
php strnatcmp()函数的用法总结
2013/11/27 PHP
php格式化时间戳显示友好的时间实现思路及代码
2014/10/23 PHP
PHP实现的策略模式示例
2019/03/20 PHP
自己的js工具_Form 封装
2009/08/21 Javascript
超轻量级的基于jquery的三级展开列表
2011/04/26 Javascript
用js判断页面刷新或关闭的方法(onbeforeunload与onunload事件)
2012/06/22 Javascript
javascript弹出层输入框(示例代码)
2013/12/11 Javascript
javascript实现存储hmtl字符串示例
2014/04/25 Javascript
jquery处理json数据实例分析
2014/06/03 Javascript
jqueryMobile 动态添加元素,展示刷新视图的实现方法
2016/05/28 Javascript
Bootstrap标签页(Tab)插件使用方法
2017/03/21 Javascript
JS+HTML5实现图片在线预览功能
2017/07/22 Javascript
浅谈在vue中使用mint-ui swipe遇到的问题
2018/09/27 Javascript
Vue中JS动画与Velocity.js的结合使用
2019/02/13 Javascript
微信小程序wepy框架学习和使用心得详解
2019/05/24 Javascript
Vue-router 报错NavigationDuplicated的解决方法
2020/03/31 Javascript
vue动态加载SVG文件并修改节点数据的操作代码
2020/08/17 Javascript
[07:55]2014DOTA2 TI正赛第三日 VG上演推进荣耀DKEG告别
2014/07/21 DOTA
python Flask 装饰器顺序问题解决
2018/08/08 Python
python基于socket进行端口转发实现后门隐藏的示例
2019/07/25 Python
Python迭代器模块itertools使用原理解析
2019/12/11 Python
Python解释器以及PyCharm的安装教程图文详解
2020/02/26 Python
Django 再谈一谈json序列化
2020/03/16 Python
Python内置函数及功能简介汇总
2020/10/13 Python
python3实现语音转文字(语音识别)和文字转语音(语音合成)
2020/10/14 Python
使用HTML5 Canvas绘制圆角矩形及相关的一些应用举例
2016/03/22 HTML / CSS
Android面试题及答案
2015/09/04 面试题
干部下基层实施方案
2014/03/14 职场文书
做人民满意的公务员活动方案
2014/08/25 职场文书
2014年终工作总结范本
2014/12/15 职场文书
孩子满月酒答谢词
2015/09/30 职场文书
导游词之桂林山水
2019/09/20 职场文书
python函数指定默认值的实例讲解
2021/03/29 Python
Python机器学习之基础概述
2021/05/19 Python
MySQL里面的子查询的基本使用
2021/08/02 MySQL
Spring 使用注解开发
2022/05/20 Java/Android