TensorFlow实现随机训练和批量训练的方法


Posted in Python onApril 28, 2018

TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。

为了TensorFlow计算变量梯度来让反向传播工作,我们必须度量一个或者多个样本的损失。

随机训练会一次随机抽样训练数据和目标数据对完成训练。另外一个可选项是,一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。这里将显示如何扩展前面的回归算法的例子——使用随机训练和批量训练。

批量训练和随机训练的不同之处在于它们的优化器方法和收敛。

# 随机训练和批量训练
#----------------------------------
#
# This python function illustrates two different training methods:
# batch and stochastic training. For each model, we will use
# a regression model that predicts one model variable.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 随机训练:
# Create graph
sess = tf.Session()

# 声明数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加操作到图
my_output = tf.multiply(x_data, A)

# 增加L2损失函数
loss = tf.square(my_output - y_target)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_stochastic = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100)
 rand_x = [x_vals[rand_index]]
 rand_y = [y_vals[rand_index]]
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_stochastic.append(temp_loss)


# 批量训练:
# 重置计算图
ops.reset_default_graph()
sess = tf.Session()

# 声明批量大小
# 批量大小是指通过计算图一次传入多少训练数据
batch_size = 20

# 声明模型的数据、占位符
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加矩阵乘法操作(矩阵乘法不满足交换律)
my_output = tf.matmul(x_data, A)

# 增加损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_batch = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100, size=batch_size)
 rand_x = np.transpose([x_vals[rand_index]])
 rand_y = np.transpose([y_vals[rand_index]])
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_batch.append(temp_loss)

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出:

Step #5 A = [ 1.47604525]
Loss = [ 72.55678558]
Step #10 A = [ 3.01128507]
Loss = [ 48.22986221]
Step #15 A = [ 4.27042341]
Loss = [ 28.97912598]
Step #20 A = [ 5.2984333]
Loss = [ 16.44779968]
Step #25 A = [ 6.17473984]
Loss = [ 16.373312]
Step #30 A = [ 6.89866304]
Loss = [ 11.71054649]
Step #35 A = [ 7.39849901]
Loss = [ 6.42773056]
Step #40 A = [ 7.84618378]
Loss = [ 5.92940331]
Step #45 A = [ 8.15709782]
Loss = [ 0.2142024]
Step #50 A = [ 8.54818344]
Loss = [ 7.11651039]
Step #55 A = [ 8.82354641]
Loss = [ 1.47823763]
Step #60 A = [ 9.07896614]
Loss = [ 3.08244276]
Step #65 A = [ 9.24868107]
Loss = [ 0.01143846]
Step #70 A = [ 9.36772251]
Loss = [ 2.10078788]
Step #75 A = [ 9.49171734]
Loss = [ 3.90913701]
Step #80 A = [ 9.6622715]
Loss = [ 4.80727625]
Step #85 A = [ 9.73786926]
Loss = [ 0.39915398]
Step #90 A = [ 9.81853104]
Loss = [ 0.14876099]
Step #95 A = [ 9.90371323]
Loss = [ 0.01657014]
Step #100 A = [ 9.86669159]
Loss = [ 0.444787]
Step #5 A = [[ 2.34371352]]
Loss = 58.766
Step #10 A = [[ 3.74766445]]
Loss = 38.4875
Step #15 A = [[ 4.88928795]]
Loss = 27.5632
Step #20 A = [[ 5.82038736]]
Loss = 17.9523
Step #25 A = [[ 6.58999157]]
Loss = 13.3245
Step #30 A = [[ 7.20851326]]
Loss = 8.68099
Step #35 A = [[ 7.71694899]]
Loss = 4.60659
Step #40 A = [[ 8.1296711]]
Loss = 4.70107
Step #45 A = [[ 8.47107315]]
Loss = 3.28318
Step #50 A = [[ 8.74283409]]
Loss = 1.99057
Step #55 A = [[ 8.98811722]]
Loss = 2.66906
Step #60 A = [[ 9.18062305]]
Loss = 3.26207
Step #65 A = [[ 9.31655025]]
Loss = 2.55459
Step #70 A = [[ 9.43130589]]
Loss = 1.95839
Step #75 A = [[ 9.55670166]]
Loss = 1.46504
Step #80 A = [[ 9.6354847]]
Loss = 1.49021
Step #85 A = [[ 9.73470974]]
Loss = 1.53289
Step #90 A = [[ 9.77956581]]
Loss = 1.52173
Step #95 A = [[ 9.83666706]]
Loss = 0.819207
Step #100 A = [[ 9.85569191]]
Loss = 1.2197

TensorFlow实现随机训练和批量训练的方法

训练类型 优点 缺点
随机训练 脱离局部最小 一般需更多次迭代才收敛
批量训练 快速得到最小损失 耗费更多计算资源

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python3中初学者应会的一些基本的提升效率的小技巧
Mar 31 Python
实例探究Python以并发方式编写高性能端口扫描器的方法
Jun 14 Python
老生常谈Python基础之字符编码
Jun 14 Python
python编写朴素贝叶斯用于文本分类
Dec 21 Python
设置python3为默认python的方法
Oct 31 Python
python生成以及打开json、csv和txt文件的实例
Nov 16 Python
python实现一个简单的ping工具方法
Jan 31 Python
Python有参函数使用代码实例
Jan 06 Python
Pymysql实现往表中插入数据过程解析
Jun 02 Python
Python 必须了解的5种高级特征
Sep 10 Python
关于Python 解决Python3.9 pandas.read_excel(‘xxx.xlsx‘)报错的问题
Nov 28 Python
python 实现图片批量压缩的示例
Dec 18 Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
You might like
php数组键值用法实例分析
2015/02/27 PHP
PHP实现防盗链的方法分析
2017/07/25 PHP
laravel中的fillable和guarded属性详解
2019/10/23 PHP
新浪刚打开页面出来的全屏广告代码
2007/04/02 Javascript
javascript 学习笔记(onchange等)
2010/11/14 Javascript
jQuery EasyUI API 中文文档 - ProgressBar 进度条
2011/09/29 Javascript
JQuery拖拽元素改变大小尺寸实现代码
2012/12/10 Javascript
javascript 禁用IE工具栏,导航栏等等实现代码
2013/04/01 Javascript
JavaScript中的console.assert()函数介绍
2014/12/29 Javascript
Javascript中获取对象的原型对象的方法小结
2015/02/25 Javascript
JQuery select(下拉框)操作方法汇总
2015/04/15 Javascript
javascript制作的滑动图片菜单
2015/05/15 Javascript
JS+JSP通过img标签调用实现静态页面访问次数统计的方法
2015/12/14 Javascript
javascript图片延迟加载实现方法及思路
2015/12/31 Javascript
jQuery绑定事件的几种实现方式
2016/05/09 Javascript
精彩的Bootstrap案例分享 重点在注释!(选项卡、栅格布局)
2016/07/01 Javascript
jQuery中animate()的使用方法及解决$(”body“).animate({“scrollTop”:top})不被Firefox支持的问题
2017/04/04 jQuery
Bootstrap fileinput文件上传组件使用详解
2017/06/06 Javascript
浅谈Webpack 是如何加载模块的
2018/05/24 Javascript
JavaScript事件发布/订阅模式原理与用法分析
2018/08/21 Javascript
Vue 中的受控与非受控组件的实现
2018/12/17 Javascript
跟老齐学Python之编写类之一创建实例
2014/10/11 Python
python图形绘制奥运五环实例讲解
2019/09/14 Python
tensorflow没有output结点,存储成pb文件的例子
2020/01/04 Python
Python中关于logging模块的学习笔记
2020/06/03 Python
Python常用base64 md5 aes des crc32加密解密方法汇总
2020/11/06 Python
用python制作个视频下载器
2021/02/01 Python
OPPO手机官方商城:中国手机市场出货量第一品牌
2017/10/18 全球购物
教师试用期自我鉴定
2014/02/12 职场文书
运动会演讲稿
2014/05/07 职场文书
大学生作弊检讨书
2014/09/11 职场文书
乡镇党的群众路线教育实践活动个人对照检查材料
2014/09/23 职场文书
2015年普法依法治理工作总结
2015/05/26 职场文书
爱国主义电影观后感
2015/06/18 职场文书
个人售房合同协议书
2016/03/21 职场文书
详解PHP服务器如何在有限的资源里最大提升并发能力
2021/05/25 PHP