TensorFlow实现随机训练和批量训练的方法


Posted in Python onApril 28, 2018

TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。

为了TensorFlow计算变量梯度来让反向传播工作,我们必须度量一个或者多个样本的损失。

随机训练会一次随机抽样训练数据和目标数据对完成训练。另外一个可选项是,一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。这里将显示如何扩展前面的回归算法的例子——使用随机训练和批量训练。

批量训练和随机训练的不同之处在于它们的优化器方法和收敛。

# 随机训练和批量训练
#----------------------------------
#
# This python function illustrates two different training methods:
# batch and stochastic training. For each model, we will use
# a regression model that predicts one model variable.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 随机训练:
# Create graph
sess = tf.Session()

# 声明数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加操作到图
my_output = tf.multiply(x_data, A)

# 增加L2损失函数
loss = tf.square(my_output - y_target)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_stochastic = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100)
 rand_x = [x_vals[rand_index]]
 rand_y = [y_vals[rand_index]]
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_stochastic.append(temp_loss)


# 批量训练:
# 重置计算图
ops.reset_default_graph()
sess = tf.Session()

# 声明批量大小
# 批量大小是指通过计算图一次传入多少训练数据
batch_size = 20

# 声明模型的数据、占位符
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加矩阵乘法操作(矩阵乘法不满足交换律)
my_output = tf.matmul(x_data, A)

# 增加损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_batch = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100, size=batch_size)
 rand_x = np.transpose([x_vals[rand_index]])
 rand_y = np.transpose([y_vals[rand_index]])
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_batch.append(temp_loss)

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出:

Step #5 A = [ 1.47604525]
Loss = [ 72.55678558]
Step #10 A = [ 3.01128507]
Loss = [ 48.22986221]
Step #15 A = [ 4.27042341]
Loss = [ 28.97912598]
Step #20 A = [ 5.2984333]
Loss = [ 16.44779968]
Step #25 A = [ 6.17473984]
Loss = [ 16.373312]
Step #30 A = [ 6.89866304]
Loss = [ 11.71054649]
Step #35 A = [ 7.39849901]
Loss = [ 6.42773056]
Step #40 A = [ 7.84618378]
Loss = [ 5.92940331]
Step #45 A = [ 8.15709782]
Loss = [ 0.2142024]
Step #50 A = [ 8.54818344]
Loss = [ 7.11651039]
Step #55 A = [ 8.82354641]
Loss = [ 1.47823763]
Step #60 A = [ 9.07896614]
Loss = [ 3.08244276]
Step #65 A = [ 9.24868107]
Loss = [ 0.01143846]
Step #70 A = [ 9.36772251]
Loss = [ 2.10078788]
Step #75 A = [ 9.49171734]
Loss = [ 3.90913701]
Step #80 A = [ 9.6622715]
Loss = [ 4.80727625]
Step #85 A = [ 9.73786926]
Loss = [ 0.39915398]
Step #90 A = [ 9.81853104]
Loss = [ 0.14876099]
Step #95 A = [ 9.90371323]
Loss = [ 0.01657014]
Step #100 A = [ 9.86669159]
Loss = [ 0.444787]
Step #5 A = [[ 2.34371352]]
Loss = 58.766
Step #10 A = [[ 3.74766445]]
Loss = 38.4875
Step #15 A = [[ 4.88928795]]
Loss = 27.5632
Step #20 A = [[ 5.82038736]]
Loss = 17.9523
Step #25 A = [[ 6.58999157]]
Loss = 13.3245
Step #30 A = [[ 7.20851326]]
Loss = 8.68099
Step #35 A = [[ 7.71694899]]
Loss = 4.60659
Step #40 A = [[ 8.1296711]]
Loss = 4.70107
Step #45 A = [[ 8.47107315]]
Loss = 3.28318
Step #50 A = [[ 8.74283409]]
Loss = 1.99057
Step #55 A = [[ 8.98811722]]
Loss = 2.66906
Step #60 A = [[ 9.18062305]]
Loss = 3.26207
Step #65 A = [[ 9.31655025]]
Loss = 2.55459
Step #70 A = [[ 9.43130589]]
Loss = 1.95839
Step #75 A = [[ 9.55670166]]
Loss = 1.46504
Step #80 A = [[ 9.6354847]]
Loss = 1.49021
Step #85 A = [[ 9.73470974]]
Loss = 1.53289
Step #90 A = [[ 9.77956581]]
Loss = 1.52173
Step #95 A = [[ 9.83666706]]
Loss = 0.819207
Step #100 A = [[ 9.85569191]]
Loss = 1.2197

TensorFlow实现随机训练和批量训练的方法

训练类型 优点 缺点
随机训练 脱离局部最小 一般需更多次迭代才收敛
批量训练 快速得到最小损失 耗费更多计算资源

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用range函数计算一组数和的方法
May 07 Python
python自动翻译实现方法
May 28 Python
Python中运算符"=="和"is"的详解
Oct 08 Python
python 读取视频,处理后,实时计算帧数fps的方法
Jul 10 Python
python实现时间o(1)的最小栈的实例代码
Jul 23 Python
详解django中使用定时任务的方法
Sep 27 Python
Python实现KNN(K-近邻)算法的示例代码
Mar 05 Python
使用Python opencv实现视频与图片的相互转换
Jul 08 Python
python语言中有算法吗
Jun 16 Python
安装pyecharts1.8.0版本后导入pyecharts模块绘图时报错: “所有图表类型将在 v1.9.0 版本开始强制使用 ChartItem 进行数据项配置 ”的解决方法
Aug 18 Python
django inspectdb 操作已有数据库数据的使用步骤
Feb 07 Python
python 爬取腾讯视频评论的实现步骤
Feb 18 Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
You might like
Mysql数据库操作类( 1127版,提供源码下载 )
2010/12/02 PHP
Yii数据库缓存实例分析
2016/03/29 PHP
php使用curl代理实现抓取数据的方法
2017/02/03 PHP
php微信开发之音乐回复功能
2018/06/14 PHP
提高代码性能技巧谈—以创建千行表格为例
2006/07/01 Javascript
你未必知道的JavaScript和CSS交互的5种方法
2014/04/02 Javascript
JavaScript获取路径设计源码
2014/05/22 Javascript
原生js实现的贪吃蛇网页版游戏完整实例
2015/05/18 Javascript
JavaScript实现的类字典插入或更新方法实例
2015/07/10 Javascript
AngularJS入门教程之ng-class 指令用法
2016/08/01 Javascript
深入理解javascript中的 “this”
2017/01/17 Javascript
Angular.js中控制器之间的传值详解
2017/04/24 Javascript
详解nodejs express下使用redis管理session
2017/04/24 NodeJs
JavaScript仿微信(电话)联系人列表滑动字母索引实例讲解(推荐)
2017/08/16 Javascript
微信小程序之GET请求的实例详解
2017/09/29 Javascript
javaScript实现鼠标在文字上悬浮时弹出悬浮层效果
2020/04/12 Javascript
vue中$refs的用法及作用详解
2018/04/24 Javascript
Chart.js 轻量级HTML5图表绘制工具库(知识整理)
2018/05/22 Javascript
原生JS实现轮播图效果
2018/10/12 Javascript
VUE2.0+ElementUI2.0表格el-table实现表头扩展el-tooltip
2018/11/30 Javascript
详解小程序云开发数据库
2019/05/20 Javascript
[03:41]DOTA2上海特锦赛小组赛第三日recap精彩回顾
2016/02/28 DOTA
使用PyCharm配合部署Python的Django框架的配置纪实
2015/11/19 Python
python实现聚类算法原理
2018/02/12 Python
pyqt5的QComboBox 使用模板的具体方法
2018/09/06 Python
不管你的Python报什么错,用这个模块就能正常运行
2018/09/14 Python
给Python学习者的文件读写指南(含基础与进阶)
2020/01/29 Python
员工培训邀请函
2014/02/02 职场文书
公司员工检讨书
2014/02/08 职场文书
护士自我鉴定总结
2014/03/24 职场文书
鲁迅故里导游词
2015/02/05 职场文书
第一军规观后感
2015/06/12 职场文书
回复函范文
2015/07/14 职场文书
FP-growth算法发现频繁项集——构建FP树
2021/06/24 Python
【TED出品】天梯非主流开心游1700 划水骑士
2022/03/31 魔兽争霸
Elasticsearch 批量操作
2022/04/19 Python