TensorFlow实现随机训练和批量训练的方法


Posted in Python onApril 28, 2018

TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。

为了TensorFlow计算变量梯度来让反向传播工作,我们必须度量一个或者多个样本的损失。

随机训练会一次随机抽样训练数据和目标数据对完成训练。另外一个可选项是,一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。这里将显示如何扩展前面的回归算法的例子——使用随机训练和批量训练。

批量训练和随机训练的不同之处在于它们的优化器方法和收敛。

# 随机训练和批量训练
#----------------------------------
#
# This python function illustrates two different training methods:
# batch and stochastic training. For each model, we will use
# a regression model that predicts one model variable.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 随机训练:
# Create graph
sess = tf.Session()

# 声明数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加操作到图
my_output = tf.multiply(x_data, A)

# 增加L2损失函数
loss = tf.square(my_output - y_target)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_stochastic = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100)
 rand_x = [x_vals[rand_index]]
 rand_y = [y_vals[rand_index]]
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_stochastic.append(temp_loss)


# 批量训练:
# 重置计算图
ops.reset_default_graph()
sess = tf.Session()

# 声明批量大小
# 批量大小是指通过计算图一次传入多少训练数据
batch_size = 20

# 声明模型的数据、占位符
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加矩阵乘法操作(矩阵乘法不满足交换律)
my_output = tf.matmul(x_data, A)

# 增加损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_batch = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100, size=batch_size)
 rand_x = np.transpose([x_vals[rand_index]])
 rand_y = np.transpose([y_vals[rand_index]])
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_batch.append(temp_loss)

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出:

Step #5 A = [ 1.47604525]
Loss = [ 72.55678558]
Step #10 A = [ 3.01128507]
Loss = [ 48.22986221]
Step #15 A = [ 4.27042341]
Loss = [ 28.97912598]
Step #20 A = [ 5.2984333]
Loss = [ 16.44779968]
Step #25 A = [ 6.17473984]
Loss = [ 16.373312]
Step #30 A = [ 6.89866304]
Loss = [ 11.71054649]
Step #35 A = [ 7.39849901]
Loss = [ 6.42773056]
Step #40 A = [ 7.84618378]
Loss = [ 5.92940331]
Step #45 A = [ 8.15709782]
Loss = [ 0.2142024]
Step #50 A = [ 8.54818344]
Loss = [ 7.11651039]
Step #55 A = [ 8.82354641]
Loss = [ 1.47823763]
Step #60 A = [ 9.07896614]
Loss = [ 3.08244276]
Step #65 A = [ 9.24868107]
Loss = [ 0.01143846]
Step #70 A = [ 9.36772251]
Loss = [ 2.10078788]
Step #75 A = [ 9.49171734]
Loss = [ 3.90913701]
Step #80 A = [ 9.6622715]
Loss = [ 4.80727625]
Step #85 A = [ 9.73786926]
Loss = [ 0.39915398]
Step #90 A = [ 9.81853104]
Loss = [ 0.14876099]
Step #95 A = [ 9.90371323]
Loss = [ 0.01657014]
Step #100 A = [ 9.86669159]
Loss = [ 0.444787]
Step #5 A = [[ 2.34371352]]
Loss = 58.766
Step #10 A = [[ 3.74766445]]
Loss = 38.4875
Step #15 A = [[ 4.88928795]]
Loss = 27.5632
Step #20 A = [[ 5.82038736]]
Loss = 17.9523
Step #25 A = [[ 6.58999157]]
Loss = 13.3245
Step #30 A = [[ 7.20851326]]
Loss = 8.68099
Step #35 A = [[ 7.71694899]]
Loss = 4.60659
Step #40 A = [[ 8.1296711]]
Loss = 4.70107
Step #45 A = [[ 8.47107315]]
Loss = 3.28318
Step #50 A = [[ 8.74283409]]
Loss = 1.99057
Step #55 A = [[ 8.98811722]]
Loss = 2.66906
Step #60 A = [[ 9.18062305]]
Loss = 3.26207
Step #65 A = [[ 9.31655025]]
Loss = 2.55459
Step #70 A = [[ 9.43130589]]
Loss = 1.95839
Step #75 A = [[ 9.55670166]]
Loss = 1.46504
Step #80 A = [[ 9.6354847]]
Loss = 1.49021
Step #85 A = [[ 9.73470974]]
Loss = 1.53289
Step #90 A = [[ 9.77956581]]
Loss = 1.52173
Step #95 A = [[ 9.83666706]]
Loss = 0.819207
Step #100 A = [[ 9.85569191]]
Loss = 1.2197

TensorFlow实现随机训练和批量训练的方法

训练类型 优点 缺点
随机训练 脱离局部最小 一般需更多次迭代才收敛
批量训练 快速得到最小损失 耗费更多计算资源

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作gmail实例
Jan 14 Python
Python中处理字符串之isalpha()方法的使用
May 18 Python
python 安装virtualenv和virtualenvwrapper的方法
Jan 13 Python
使用anaconda的pip安装第三方python包的操作步骤
Jun 11 Python
Python实现定时自动关闭的tkinter窗口方法
Feb 16 Python
python实现贪吃蛇游戏源码
Mar 21 Python
django admin 根据choice字段选择的不同来显示不同的页面方式
May 13 Python
python能自学吗
Jun 18 Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 Python
Python通过zookeeper实现分布式服务代码解析
Jul 22 Python
pytorch 中nn.Dropout的使用说明
May 20 Python
python实现MD5进行文件去重的示例代码
Jul 09 Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
You might like
php的chr和ord函数实现字符加减乘除运算实现代码
2011/12/05 PHP
php ajax异步读取rss文档数据
2016/03/29 PHP
[原创]php实现子字符串位置相互对调互换的方法
2016/06/02 PHP
PHP微信开发之根据用户回复关键词\位置返回附近信息
2016/06/24 PHP
PHP中SERIALIZE和JSON的序列化与反序列化操作区别分析
2016/10/11 PHP
总结PHP内存释放以及垃圾回收
2018/03/29 PHP
jquery autocomplete自动完成插件的的使用方法
2010/08/07 Javascript
JQERY limittext 插件0.2版(长内容限制显示)
2010/08/27 Javascript
angularjs在ng-repeat中使用ng-model遇到的问题
2016/01/21 Javascript
jQuery无刷新上传之uploadify3.1简单使用
2016/06/18 Javascript
jQuery弹出下拉列表插件(实现kindeditor的@功能)
2016/08/16 Javascript
Vue2.0 从零开始_环境搭建操作步骤
2017/06/14 Javascript
Vue 中使用 CSS Modules优雅方法
2018/04/09 Javascript
在小程序中使用Echart图表的示例代码
2018/08/02 Javascript
vue-router启用history模式下的开发及非根目录部署方法
2018/12/23 Javascript
微信小程序录音实现功能并上传(使用node解析接收)
2020/02/26 Javascript
JavaScript随机数的组合问题案例分析
2020/05/16 Javascript
jQuery实现倒计时功能完整示例
2020/06/01 jQuery
python判断给定的字符串是否是有效日期的方法
2015/05/13 Python
Python中使用strip()方法删除字符串中空格的教程
2015/05/20 Python
Python图像处理之简单画板实现方法示例
2018/08/30 Python
python保存二维数组到txt文件中的方法
2018/11/15 Python
Django关于admin的使用技巧和知识点
2020/02/10 Python
matplotlib之多边形选区(PolygonSelector)的使用
2021/02/24 Python
环法自行车赛官方商店:Le Tour de France
2017/08/27 全球购物
德国2018年度最佳在线药房:Bodfeld Apotheke
2019/11/04 全球购物
参观考察邀请函范文
2014/01/29 职场文书
全国优秀辅导员事迹材料
2014/05/14 职场文书
产品包装策划方案
2014/05/18 职场文书
群众路线教育实践活动批评与自我批评
2014/09/15 职场文书
2015年大学生党员承诺书
2015/04/27 职场文书
复兴之路展览观后感
2015/06/02 职场文书
2015年卫生院健康教育工作总结
2015/07/24 职场文书
高中体育课教学反思
2016/02/16 职场文书
Win11怎样将锁屏账户头像图片改成动画视频
2021/11/21 数码科技
Go gorilla/sessions库安装使用
2022/08/14 Golang