TensorFlow实现随机训练和批量训练的方法


Posted in Python onApril 28, 2018

TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。

为了TensorFlow计算变量梯度来让反向传播工作,我们必须度量一个或者多个样本的损失。

随机训练会一次随机抽样训练数据和目标数据对完成训练。另外一个可选项是,一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。这里将显示如何扩展前面的回归算法的例子——使用随机训练和批量训练。

批量训练和随机训练的不同之处在于它们的优化器方法和收敛。

# 随机训练和批量训练
#----------------------------------
#
# This python function illustrates two different training methods:
# batch and stochastic training. For each model, we will use
# a regression model that predicts one model variable.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 随机训练:
# Create graph
sess = tf.Session()

# 声明数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加操作到图
my_output = tf.multiply(x_data, A)

# 增加L2损失函数
loss = tf.square(my_output - y_target)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_stochastic = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100)
 rand_x = [x_vals[rand_index]]
 rand_y = [y_vals[rand_index]]
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_stochastic.append(temp_loss)


# 批量训练:
# 重置计算图
ops.reset_default_graph()
sess = tf.Session()

# 声明批量大小
# 批量大小是指通过计算图一次传入多少训练数据
batch_size = 20

# 声明模型的数据、占位符
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加矩阵乘法操作(矩阵乘法不满足交换律)
my_output = tf.matmul(x_data, A)

# 增加损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_batch = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100, size=batch_size)
 rand_x = np.transpose([x_vals[rand_index]])
 rand_y = np.transpose([y_vals[rand_index]])
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_batch.append(temp_loss)

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出:

Step #5 A = [ 1.47604525]
Loss = [ 72.55678558]
Step #10 A = [ 3.01128507]
Loss = [ 48.22986221]
Step #15 A = [ 4.27042341]
Loss = [ 28.97912598]
Step #20 A = [ 5.2984333]
Loss = [ 16.44779968]
Step #25 A = [ 6.17473984]
Loss = [ 16.373312]
Step #30 A = [ 6.89866304]
Loss = [ 11.71054649]
Step #35 A = [ 7.39849901]
Loss = [ 6.42773056]
Step #40 A = [ 7.84618378]
Loss = [ 5.92940331]
Step #45 A = [ 8.15709782]
Loss = [ 0.2142024]
Step #50 A = [ 8.54818344]
Loss = [ 7.11651039]
Step #55 A = [ 8.82354641]
Loss = [ 1.47823763]
Step #60 A = [ 9.07896614]
Loss = [ 3.08244276]
Step #65 A = [ 9.24868107]
Loss = [ 0.01143846]
Step #70 A = [ 9.36772251]
Loss = [ 2.10078788]
Step #75 A = [ 9.49171734]
Loss = [ 3.90913701]
Step #80 A = [ 9.6622715]
Loss = [ 4.80727625]
Step #85 A = [ 9.73786926]
Loss = [ 0.39915398]
Step #90 A = [ 9.81853104]
Loss = [ 0.14876099]
Step #95 A = [ 9.90371323]
Loss = [ 0.01657014]
Step #100 A = [ 9.86669159]
Loss = [ 0.444787]
Step #5 A = [[ 2.34371352]]
Loss = 58.766
Step #10 A = [[ 3.74766445]]
Loss = 38.4875
Step #15 A = [[ 4.88928795]]
Loss = 27.5632
Step #20 A = [[ 5.82038736]]
Loss = 17.9523
Step #25 A = [[ 6.58999157]]
Loss = 13.3245
Step #30 A = [[ 7.20851326]]
Loss = 8.68099
Step #35 A = [[ 7.71694899]]
Loss = 4.60659
Step #40 A = [[ 8.1296711]]
Loss = 4.70107
Step #45 A = [[ 8.47107315]]
Loss = 3.28318
Step #50 A = [[ 8.74283409]]
Loss = 1.99057
Step #55 A = [[ 8.98811722]]
Loss = 2.66906
Step #60 A = [[ 9.18062305]]
Loss = 3.26207
Step #65 A = [[ 9.31655025]]
Loss = 2.55459
Step #70 A = [[ 9.43130589]]
Loss = 1.95839
Step #75 A = [[ 9.55670166]]
Loss = 1.46504
Step #80 A = [[ 9.6354847]]
Loss = 1.49021
Step #85 A = [[ 9.73470974]]
Loss = 1.53289
Step #90 A = [[ 9.77956581]]
Loss = 1.52173
Step #95 A = [[ 9.83666706]]
Loss = 0.819207
Step #100 A = [[ 9.85569191]]
Loss = 1.2197

TensorFlow实现随机训练和批量训练的方法

训练类型 优点 缺点
随机训练 脱离局部最小 一般需更多次迭代才收敛
批量训练 快速得到最小损失 耗费更多计算资源

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python代码写的12306订票代码
Dec 20 Python
python爬虫_微信公众号推送信息爬取的实例
Oct 23 Python
Python从ZabbixAPI获取信息及实现Zabbix-API 监控的方法
Sep 17 Python
解决pip install xxx报错SyntaxError: invalid syntax的问题
Nov 30 Python
pycharm中显示CSS提示的知识点总结
Jul 29 Python
基于python修改srt字幕的时间轴
Feb 03 Python
Django 实现将图片转为Base64,然后使用json传输
Mar 27 Python
完美解决jupyter由于无法import新包的问题
May 26 Python
python爬虫实例之获取动漫截图
May 31 Python
使用pycharm和pylint检查python代码规范操作
Jun 09 Python
Python3利用openpyxl读写Excel文件的方法实例
Feb 03 Python
Django debug为True时,css加载失败的解决方案
Apr 24 Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
You might like
PHP parse_url 一个好用的函数
2009/10/03 PHP
怎样给PHP源代码加密?PHP二进制加密与解密的解决办法
2013/04/22 PHP
php实现批量下载百度云盘文件例子分享
2014/04/10 PHP
ThinkPHP实现ajax仿官网搜索功能实例
2014/12/02 PHP
PHP简单读取xml文件的方法示例
2017/04/20 PHP
PHP仿tp实现mvc框架基本设计思路与实现方法分析
2018/05/23 PHP
PHP扩展mcrypt实现的AES加密功能示例
2019/01/29 PHP
PHP上传图片到数据库并显示的实例代码
2019/12/20 PHP
jquery实现div拖拽宽度示例代码
2013/07/31 Javascript
jquerydom对象的事件隐藏显示和对象数组示例
2013/12/10 Javascript
Javascript selection的兼容性写法介绍
2013/12/20 Javascript
jquery模拟LCD 时钟的html文件源代码
2014/06/16 Javascript
jQuery中DOM树操作之复制元素的方法
2015/01/23 Javascript
JavaScript判断手机号运营商是移动、联通、电信还是其他(代码简单)
2015/09/25 Javascript
js判断文本框输入的内容是否为数字
2015/12/23 Javascript
jQuery之简单的表单验证实例
2016/07/07 Javascript
jsonp跨域请求实现示例
2017/03/13 Javascript
Vue组件通信之Bus的具体使用
2017/12/28 Javascript
ES6之模版字符串的具体使用
2018/05/17 Javascript
详解Vue+ElementUI从零开始搭建自己的网站(一、环境搭建)
2019/04/30 Javascript
JavaScript多种图形实现代码实例
2020/06/28 Javascript
[18:20]DOTA2 HEROS教学视频教你分分钟做大人-昆卡
2014/06/11 DOTA
Python编程中使用Pillow来处理图像的基础教程
2015/11/20 Python
python实现多线程的两种方式
2016/05/22 Python
python 实现让字典的value 成为列表
2019/12/16 Python
python读取tif图片时保留其16bit的编码格式实例
2020/01/13 Python
如何使用Django Admin管理后台导入CSV
2020/11/06 Python
突袭HTML5之Javascript API扩展2—地理信息服务及地理位置API学习
2013/01/31 HTML / CSS
LACOSTE波兰官网:Polo衫、服装和鞋类
2020/09/29 全球购物
爱护公共设施的标语
2014/06/24 职场文书
2014年党员个人剖析材料
2014/10/08 职场文书
2014年评职称工作总结
2014/11/20 职场文书
小学重阳节活动总结
2015/03/24 职场文书
2015年统计员个人工作总结
2015/07/23 职场文书
2019朋友新婚祝福语精选
2019/10/10 职场文书
python字典进行运算原理及实例分享
2021/08/02 Python