tensorflow模型继续训练 fineturn实例


Posted in Python onJanuary 21, 2020

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
 
# 操作
result = tf.matmul(x, W) + b
 
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
 
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
 
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

tensorflow模型继续训练 fineturn实例

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
 
 
final result of x×W+b = [[99.99978]](目标值是100.0)
 
模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf
 
 
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
 
with tf.Session() as sess:
 
  # 初始化变量
  sess.run(tf.global_variables_initializer())
 
  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
 
  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")
 
  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
 
  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
 
  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]
 
  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
 
 
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解python中的atexit模块
Mar 07 Python
Pycharm学习教程(3) 代码运行调试
May 03 Python
Python入门_浅谈数据结构的4种基本类型
May 16 Python
Python基于回溯法子集树模板解决野人与传教士问题示例
Sep 11 Python
Python基础语言学习笔记总结(精华)
Nov 14 Python
python实现神经网络感知器算法
Dec 20 Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 Python
Django model序列化为json的方法示例
Oct 16 Python
Python数据可视化库seaborn的使用总结
Jan 15 Python
tensorflow实现读取模型中保存的值 tf.train.NewCheckpointReader
Feb 10 Python
使用PyQt5实现图片查看器的示例代码
Apr 21 Python
python使用Thread的setDaemon启动后台线程教程
Apr 25 Python
tensorflow ckpt模型和pb模型获取节点名称,及ckpt转pb模型实例
Jan 21 #Python
tensorflow查看ckpt各节点名称实例
Jan 21 #Python
python同义词替换的实现(jieba分词)
Jan 21 #Python
tensorflow模型保存、加载之变量重命名实例
Jan 21 #Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 #Python
tensorflow如何继续训练之前保存的模型实例
Jan 21 #Python
在tensorflow中设置保存checkpoint的最大数量实例
Jan 21 #Python
You might like
php获取文件内容最后一行示例
2014/01/09 PHP
PHP实现的链式队列结构示例
2017/09/15 PHP
artDialog双击会关闭对话框的修改过程分享
2013/08/05 Javascript
jQuery将所有被选中的checkbox某个属性值连接成字符串的方法
2015/01/24 Javascript
使用Meteor配合Node.js编写实时聊天应用的范例
2015/06/23 Javascript
学习Angularjs分页指令
2016/07/01 Javascript
微信小程序 简单教程实例详解
2017/01/13 Javascript
jQuery EasyUI Accordion可伸缩面板组件使用详解
2017/02/28 Javascript
Node.js查找当前目录下文件夹实例代码
2017/03/07 Javascript
vue.js 实现点击展开收起动画效果
2018/07/07 Javascript
微信小程序自定义组件封装及父子间组件传值的方法
2018/08/28 Javascript
如何通过shell脚本自动生成vue文件详解
2019/09/10 Javascript
JavaScript实现轮播图片完整代码
2020/03/07 Javascript
jQuery+ajax实现文件上传功能
2020/12/22 jQuery
使用Mixin设计模式进行Python编程的方法讲解
2016/06/21 Python
让python 3支持mysqldb的解决方法
2017/02/14 Python
Python 40行代码实现人脸识别功能
2017/04/02 Python
微信跳一跳python辅助脚本(总结)
2018/01/11 Python
Python反转序列的方法实例分析
2018/03/21 Python
python制作图片缩略图
2019/04/30 Python
Python音频操作工具PyAudio上手教程详解
2019/06/26 Python
基于python实现语音录入识别代码实例
2020/01/17 Python
基于Python3.6中的OpenCV实现图片色彩空间的转换
2020/02/03 Python
python yield和Generator函数用法详解
2020/02/10 Python
Python reduce函数作用及实例解析
2020/05/08 Python
Python性能分析工具py-spy原理用法解析
2020/07/27 Python
python热力图实现简单方法
2021/01/29 Python
购买美国制造的相框和画框架:Picture Frames
2018/08/14 全球购物
e路東瀛(JAPANiCAN)香港:日本旅游、日本酒店和温泉旅馆预订
2018/11/21 全球购物
家庭教育先进个人事迹材料
2014/01/24 职场文书
优秀员工推荐信
2014/05/10 职场文书
升学宴演讲稿
2014/09/01 职场文书
颂军魂爱军营演讲稿
2014/09/13 职场文书
责任书格式
2015/01/29 职场文书
2016党员发展对象培训心得体会
2016/01/08 职场文书
十大最强妖精系宝可梦,哲尔尼亚斯实力最强,第五被称为大力士
2022/03/18 日漫