tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获得linux下所有挂载点(mount points)的方法
Apr 29 Python
Django数据库操作的实例(增删改查)
Sep 04 Python
一篇文章读懂Python赋值与拷贝
Apr 19 Python
python实现对指定输入的字符串逆序输出的6种方法
Apr 26 Python
python中使用iterrows()对dataframe进行遍历的实例
Jun 09 Python
numpy使用fromstring创建矩阵的实例
Jun 15 Python
将tensorflow的ckpt模型存储为npy的实例
Jul 09 Python
Python将8位的图片转为24位的图片实现方法
Oct 24 Python
python中设置超时跳过,超时退出的方式
Dec 13 Python
用python实现前向分词最大匹配算法的示例代码
Aug 06 Python
python+appium+yaml移动端自动化测试框架实现详解
Nov 24 Python
使用Python通过oBIX协议访问Niagara数据的示例
Dec 04 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
php下HTTP Response中的Chunked编码实现方法
2008/11/19 PHP
php 数据库字段复用的基本原理与示例
2011/07/22 PHP
php判断变量类型常用方法
2012/04/24 PHP
php正则表达式使用的详细介绍
2013/04/27 PHP
解析argc argv在php中的应用
2013/06/24 PHP
浅谈PHP中的那些魔术常量
2020/12/02 PHP
经典的解除许多网站无法复制文字的绝招
2006/12/31 Javascript
使用EXT实现无刷新动态调用股票信息
2008/11/01 Javascript
js 面向对象的技术创建高级 Web 应用程序
2010/02/25 Javascript
jquery.validate使用攻略 第五步 正则验证
2010/07/01 Javascript
解决IE6的PNG透明JS插件使用介绍
2013/04/17 Javascript
JS动态添加与删除select中的Option对象(示例代码)
2013/12/25 Javascript
JavaScript字符串对象的concat方法实例(用于连接两个或多个字符串)
2014/10/16 Javascript
jQuery实现动态添加、删除按钮及input输入框的方法
2017/04/27 jQuery
Vue2.0表单校验组件vee-validate的使用详解
2017/05/02 Javascript
AngularJS实现的根据数量与单价计算总价功能示例
2017/12/26 Javascript
vue实现的上传图片到数据库并显示到页面功能示例
2018/03/17 Javascript
jQuery ajax仿Google自动提示SearchSuggess功能示例
2019/03/28 jQuery
JavaScript 严格模式(use strict)用法实例分析
2020/03/04 Javascript
js实现从右往左匀速显示图片(无缝轮播)
2020/06/29 Javascript
Jquery使用each函数实现遍历及数组处理
2020/07/14 jQuery
[03:11]DOTA2上海特锦赛小组赛第一日recap精彩回顾
2016/02/28 DOTA
Python获取邮件地址的方法
2015/07/10 Python
如何获取Python简单for循环索引
2019/11/21 Python
python 图像判断,清晰度(明暗),彩色与黑白实例
2020/06/04 Python
PyCharm中配置PySide2的图文教程
2020/06/18 Python
CSS3 please 跨浏览器的CSS3产生器
2010/03/14 HTML / CSS
Needle & Thread官网:英国仙女品牌
2018/01/13 全球购物
加拿大国民体育购物网站:National Sports
2018/11/04 全球购物
英文版销售经理个人求职信
2013/11/20 职场文书
秋天的图画教学反思
2014/05/01 职场文书
学生考试舞弊检讨书
2015/01/01 职场文书
高中运动会前导词
2015/07/20 职场文书
特种设备安全管理制度
2015/08/06 职场文书
利用python Pandas实现批量拆分Excel与合并Excel
2021/05/23 Python
element tree树形组件回显数据问题解决
2022/08/14 Javascript