tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python多线程下载文件的方法
Jul 10 Python
Python实现批量更换指定目录下文件扩展名的方法
Sep 19 Python
Python实现定时任务
Feb 08 Python
Python+Socket实现基于TCP协议的客户与服务端中文自动回复聊天功能示例
Aug 31 Python
python微信跳一跳系列之棋子定位像素遍历
Feb 26 Python
Python数据类型之列表和元组的方法实例详解
Jul 08 Python
python多线程同步之文件读写控制
Feb 25 Python
python实现简单成绩录入系统
Sep 19 Python
简单了解python字符串前面加r,u的含义
Dec 26 Python
Selenium环境变量配置(火狐浏览器)及验证实现
Dec 07 Python
python如何进行基准测试
Apr 26 Python
利用For循环遍历Python字典的三种方法实例
Mar 25 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
php启动时候提示PHP startup的解决方法
2013/05/07 PHP
php实现文件下载功能的几个代码分享
2014/05/10 PHP
自己写的php curl库实现整站克隆功能
2015/02/12 PHP
php解决安全问题的方法实例
2019/09/19 PHP
防止网站内容被拷贝的一些方法与优缺点好处与坏处分析
2007/11/30 Javascript
Tips 带三角可关闭的文字提示
2010/10/06 Javascript
说说JSON和JSONP 也许你会豁然开朗
2012/09/02 Javascript
如何将JS的变量值传递给ASP变量
2012/12/10 Javascript
node.js中的querystring.stringify方法使用说明
2014/12/10 Javascript
同一个网页中实现多个JavaScript特效的方法
2015/02/02 Javascript
用Angular实时获取本地Localstorage数据,实现一个模拟后台数据登入的效果
2016/11/09 Javascript
解决bootstrap模态框数据缓存的问题方法
2018/08/10 Javascript
vue-for循环嵌套操作示例
2019/01/28 Javascript
说说如何使用Vuex进行状态管理(小结)
2019/04/14 Javascript
Vue源码解析之数据响应系统的使用
2019/04/24 Javascript
JavaScript实现图片放大镜效果
2019/06/27 Javascript
精确查找PHP WEBSHELL木马的方法(1)
2011/04/12 Python
Python使用Flask框架同时上传多个文件的方法
2015/03/21 Python
Python实现的圆形绘制(画圆)示例
2018/01/31 Python
基于python实现聊天室程序
2018/07/27 Python
Linux下python3.6.1环境配置教程
2018/09/26 Python
Python之循环结构
2019/01/15 Python
python实现nao机器人手臂动作控制
2019/04/29 Python
pyqt5对用qt designer设计的窗体实现弹出子窗口的示例
2019/06/19 Python
Django调用百度AI接口实现人脸注册登录代码实例
2020/04/23 Python
CSS3条纹背景制作的实战攻略
2016/05/31 HTML / CSS
Canon佳能美国官方商店:购买数码相机、数码单反相机、镜头和打印机
2016/11/15 全球购物
声明struct x1 { . . . }; 和typedef struct { . . . }x2;有什么不同
2012/06/02 面试题
饮料业务员岗位职责
2013/12/15 职场文书
铲车司机岗位职责
2014/03/15 职场文书
2014年党员公开承诺书范文
2014/03/28 职场文书
媒体宣传策划方案
2014/05/25 职场文书
法人委托书的范本格式
2014/09/11 职场文书
2014年幼儿园小班工作总结
2014/12/04 职场文书
MYSQL优化之数据表碎片整理详解
2022/04/03 MySQL
Python可视化神器pyecharts之绘制箱形图
2022/07/07 Python