tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的Matplotlib模块入门教程
Apr 15 Python
Python使用Scrapy爬取妹子图
May 28 Python
Tornado协程在python2.7如何返回值(实现方法)
Jun 22 Python
python实现闹钟定时播放音乐功能
Jan 25 Python
python实现在图片上画特定大小角度矩形框
Oct 24 Python
详解Ubuntu16.04安装Python3.7及其pip3并切换为默认版本
Feb 25 Python
Django REST framework 视图和路由详解
Jul 19 Python
Python爬取豆瓣视频信息代码实例
Nov 16 Python
python3 dict ndarray 存成json,并保留原数据精度的实例
Dec 06 Python
pycharm如何实现跨目录调用文件
Feb 28 Python
Python 将 QQ 好友头像生成祝福语的实现代码
May 03 Python
Python利用Faiss库实现ANN近邻搜索的方法详解
Aug 03 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
跟我学Laravel之配置Laravel
2014/10/15 PHP
ThinkPHP框架中使用Memcached缓存数据的方法
2018/03/31 PHP
一段效率很高的for循环语句使用方法
2007/08/13 Javascript
extjs fckeditor集成代码
2009/05/10 Javascript
Raphael一个用于在网页中绘制矢量图形的Javascript库
2013/01/08 Javascript
jQuery.prototype.init选择器构造函数源码思路分析
2013/02/05 Javascript
jquery仿京东导航/仿淘宝商城左侧分类导航下拉菜单效果
2013/04/24 Javascript
快速解决FusionCharts联动的中文乱码问题
2013/12/04 Javascript
jquery对table中各数据的增加、保存、删除操作示例
2014/05/14 Javascript
JQuery动画与特效实例分析
2015/02/02 Javascript
vue实现可增删查改的成绩单
2016/10/27 Javascript
jQuery简单实现MD5加密的方法
2017/03/03 Javascript
async/await让异步操作同步执行的方法详解
2019/11/01 Javascript
基于VUE实现简单的学生信息管理系统
2021/01/13 Vue.js
Python实现的检测网站挂马程序
2014/11/30 Python
简单介绍Python中的RSS处理
2015/04/13 Python
Python isinstance函数介绍
2015/04/14 Python
Python计算三角函数之asin()方法的使用
2015/05/15 Python
Python处理字符串之isspace()方法的使用
2015/05/19 Python
Python编程之变量赋值操作实例分析
2017/07/24 Python
Python实现的科学计算器功能示例
2017/08/04 Python
Python使用functools实现注解同步方法
2018/02/06 Python
Python 多维List创建的问题小结
2019/01/18 Python
python如果快速判断数字奇数偶数
2019/11/13 Python
关于pytorch中全连接神经网络搭建两种模式详解
2020/01/14 Python
CSS3的Flexbox布局的简明入门指南
2016/04/08 HTML / CSS
Manuka Doctor英国官网:真正的麦卢卡蜂蜜和护肤品
2018/10/26 全球购物
事业单位竞聘上岗实施方案
2014/03/28 职场文书
运动会广播稿50字-100字
2014/10/11 职场文书
2014年客房服务员工作总结
2014/11/18 职场文书
2014年销售工作总结与计划
2014/12/01 职场文书
离婚起诉书怎么写
2015/05/19 职场文书
员工考勤管理制度
2015/08/06 职场文书
关于公司年会的开幕词
2016/03/04 职场文书
Python一些基本的图像操作和处理总结
2021/06/23 Python
HTML5基础学习之文本标签控制
2022/03/25 HTML / CSS