tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中一些自然语言工具的使用的入门教程
Apr 13 Python
Python内置函数reversed()用法分析
Mar 20 Python
Python使用cx_Oracle模块操作Oracle数据库详解
May 07 Python
python中时间模块的基本使用教程
May 14 Python
python使用pygame模块实现坦克大战游戏
Mar 25 Python
selenium+PhantomJS爬取豆瓣读书
Aug 26 Python
Python字典底层实现原理详解
Dec 18 Python
Python递归及尾递归优化操作实例分析
Feb 01 Python
Django models文件模型变更错误解决
May 11 Python
python实现126邮箱发送邮件
May 20 Python
详解Python openpyxl库的基本应用
Feb 26 Python
Python基础之进程详解
May 21 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
文件上传程序的全部源码
2006/10/09 PHP
php类
2006/11/27 PHP
PHP使用GD库制作验证码的方法(点击验证码或看不清会刷新验证码)
2017/08/15 PHP
php对象工厂类完整示例
2018/08/09 PHP
File文件控件,选中文件(图片,flash,视频)即立即预览显示
2009/04/09 Javascript
犀利的js 函数集合
2009/06/11 Javascript
jquery.messager.js插件导致页面抖动的解决方法
2013/07/14 Javascript
js图片实时加载提供网页打开速度
2014/09/11 Javascript
Nodejs Post请求报socket hang up错误的解决办法
2014/09/25 NodeJs
JavaScript简单修改窗口大小的方法
2015/08/03 Javascript
从重置input file标签中看jQuery的 .val() 和 .attr(“value”) 区别
2016/06/12 Javascript
jQuery如何获取动态添加的元素
2016/06/24 Javascript
Vuejs第十篇之vuejs父子组件通信
2016/09/06 Javascript
Node.js开发教程之基于OnceIO框架实现文件上传和验证功能
2016/11/30 Javascript
BootStrap整体框架之基础布局组件
2016/12/15 Javascript
Nodejs读取文件时相对路径的正确写法(使用fs模块)
2017/04/27 NodeJs
vue子父组件通信的实现代码
2017/07/09 Javascript
小程序自定义日历效果
2018/12/29 Javascript
JavaScript学习笔记之DOM操作实例分析
2019/01/08 Javascript
如何实现小程序tab栏下划线动画效果
2019/05/18 Javascript
详解wepy开发小程序踩过的坑(小结)
2019/05/22 Javascript
JavaScript使用prototype属性实现继承操作示例
2020/05/22 Javascript
微信小程序实现星星评分效果
2020/11/01 Javascript
python实现的一个p2p文件传输实例
2014/06/04 Python
Python魔术方法详解
2015/02/14 Python
Python中在for循环中嵌套使用if和else语句的技巧
2016/06/20 Python
Python自定义线程类简单示例
2018/03/23 Python
python3利用tcp实现文件夹远程传输
2018/07/28 Python
python简单实现矩阵的乘,加,转置和逆运算示例
2019/07/10 Python
解决Django Haystack全文检索为空的问题
2020/05/19 Python
Keras 在fit_generator训练方式中加入图像random_crop操作
2020/07/03 Python
联想哥伦比亚网上商城:Lenovo Colombia
2017/01/10 全球购物
lululemon美国官网:瑜伽服+跑步装备
2018/11/16 全球购物
文明宿舍获奖感言
2014/02/07 职场文书
市场营销专业应届生自荐信
2014/06/19 职场文书
【海涛dota】偶遇拉娜娅 质量局德鲁伊第一视角解说
2022/04/01 DOTA