TensorFlow实现随机训练和批量训练的方法


Posted in Python onApril 28, 2018

TensorFlow更新模型变量。它能一次操作一个数据点,也可以一次操作大量数据。一个训练例子上的操作可能导致比较“古怪”的学习过程,但使用大批量的训练会造成计算成本昂贵。到底选用哪种训练类型对机器学习算法的收敛非常关键。

为了TensorFlow计算变量梯度来让反向传播工作,我们必须度量一个或者多个样本的损失。

随机训练会一次随机抽样训练数据和目标数据对完成训练。另外一个可选项是,一次大批量训练取平均损失来进行梯度计算,批量训练大小可以一次上扩到整个数据集。这里将显示如何扩展前面的回归算法的例子——使用随机训练和批量训练。

批量训练和随机训练的不同之处在于它们的优化器方法和收敛。

# 随机训练和批量训练
#----------------------------------
#
# This python function illustrates two different training methods:
# batch and stochastic training. For each model, we will use
# a regression model that predicts one model variable.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 随机训练:
# Create graph
sess = tf.Session()

# 声明数据
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加操作到图
my_output = tf.multiply(x_data, A)

# 增加L2损失函数
loss = tf.square(my_output - y_target)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_stochastic = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100)
 rand_x = [x_vals[rand_index]]
 rand_y = [y_vals[rand_index]]
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_stochastic.append(temp_loss)


# 批量训练:
# 重置计算图
ops.reset_default_graph()
sess = tf.Session()

# 声明批量大小
# 批量大小是指通过计算图一次传入多少训练数据
batch_size = 20

# 声明模型的数据、占位符
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明变量 (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加矩阵乘法操作(矩阵乘法不满足交换律)
my_output = tf.matmul(x_data, A)

# 增加损失函数
# 批量训练时损失函数是每个数据点L2损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

loss_batch = []
# 运行迭代
for i in range(100):
 rand_index = np.random.choice(100, size=batch_size)
 rand_x = np.transpose([x_vals[rand_index]])
 rand_y = np.transpose([y_vals[rand_index]])
 sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
 if (i+1)%5==0:
  print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  print('Loss = ' + str(temp_loss))
  loss_batch.append(temp_loss)

plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size=20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()

输出:

Step #5 A = [ 1.47604525]
Loss = [ 72.55678558]
Step #10 A = [ 3.01128507]
Loss = [ 48.22986221]
Step #15 A = [ 4.27042341]
Loss = [ 28.97912598]
Step #20 A = [ 5.2984333]
Loss = [ 16.44779968]
Step #25 A = [ 6.17473984]
Loss = [ 16.373312]
Step #30 A = [ 6.89866304]
Loss = [ 11.71054649]
Step #35 A = [ 7.39849901]
Loss = [ 6.42773056]
Step #40 A = [ 7.84618378]
Loss = [ 5.92940331]
Step #45 A = [ 8.15709782]
Loss = [ 0.2142024]
Step #50 A = [ 8.54818344]
Loss = [ 7.11651039]
Step #55 A = [ 8.82354641]
Loss = [ 1.47823763]
Step #60 A = [ 9.07896614]
Loss = [ 3.08244276]
Step #65 A = [ 9.24868107]
Loss = [ 0.01143846]
Step #70 A = [ 9.36772251]
Loss = [ 2.10078788]
Step #75 A = [ 9.49171734]
Loss = [ 3.90913701]
Step #80 A = [ 9.6622715]
Loss = [ 4.80727625]
Step #85 A = [ 9.73786926]
Loss = [ 0.39915398]
Step #90 A = [ 9.81853104]
Loss = [ 0.14876099]
Step #95 A = [ 9.90371323]
Loss = [ 0.01657014]
Step #100 A = [ 9.86669159]
Loss = [ 0.444787]
Step #5 A = [[ 2.34371352]]
Loss = 58.766
Step #10 A = [[ 3.74766445]]
Loss = 38.4875
Step #15 A = [[ 4.88928795]]
Loss = 27.5632
Step #20 A = [[ 5.82038736]]
Loss = 17.9523
Step #25 A = [[ 6.58999157]]
Loss = 13.3245
Step #30 A = [[ 7.20851326]]
Loss = 8.68099
Step #35 A = [[ 7.71694899]]
Loss = 4.60659
Step #40 A = [[ 8.1296711]]
Loss = 4.70107
Step #45 A = [[ 8.47107315]]
Loss = 3.28318
Step #50 A = [[ 8.74283409]]
Loss = 1.99057
Step #55 A = [[ 8.98811722]]
Loss = 2.66906
Step #60 A = [[ 9.18062305]]
Loss = 3.26207
Step #65 A = [[ 9.31655025]]
Loss = 2.55459
Step #70 A = [[ 9.43130589]]
Loss = 1.95839
Step #75 A = [[ 9.55670166]]
Loss = 1.46504
Step #80 A = [[ 9.6354847]]
Loss = 1.49021
Step #85 A = [[ 9.73470974]]
Loss = 1.53289
Step #90 A = [[ 9.77956581]]
Loss = 1.52173
Step #95 A = [[ 9.83666706]]
Loss = 0.819207
Step #100 A = [[ 9.85569191]]
Loss = 1.2197

TensorFlow实现随机训练和批量训练的方法

训练类型 优点 缺点
随机训练 脱离局部最小 一般需更多次迭代才收敛
批量训练 快速得到最小损失 耗费更多计算资源

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量生成本地ip地址的方法
Mar 23 Python
Python引用模块和查找模块路径
Mar 17 Python
Python备份目录及目录下的全部内容的实现方法
Jun 12 Python
python实现批量监控网站
Sep 09 Python
python之Socket网络编程详解
Sep 29 Python
win系统下为Python3.5安装flask-mongoengine 库
Dec 20 Python
用python实现对比两张图片的不同
Feb 05 Python
Django xadmin开启搜索功能的实现
Nov 15 Python
Python: tkinter窗口屏幕居中,设置窗口最大,最小尺寸实例
Mar 04 Python
Django User 模块之 AbstractUser 扩展详解
Mar 11 Python
Python startswith()和endswith() 方法原理解析
Apr 28 Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
Python使用matplotlib实现的图像读取、切割裁剪功能示例
Apr 28 #Python
浅谈python日志的配置文件路径问题
Apr 28 #Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 #Python
python 日志增量抓取实现方法
Apr 28 #Python
Django 使用logging打印日志的实例
Apr 28 #Python
You might like
php打印一个边长为N的实心和空心菱型的方法
2015/03/02 PHP
php将print_r处理后的数据还原为原始数组的解决方法
2016/11/02 PHP
PHP二维数组去重实例分析
2016/11/18 PHP
PhpStorm本地断点调试的方法步骤
2018/05/21 PHP
浅谈laravel5.5 belongsToMany自身的正确用法
2019/10/17 PHP
js中方法重载如何实现?以及函数的参数问题
2013/08/01 Javascript
浅谈javascript中createElement事件
2014/12/05 Javascript
js实现动画特效的文字链接鼠标悬停提示的方法
2015/03/02 Javascript
详解JavaScript中的异常处理方法
2015/06/16 Javascript
实例详解jQuery表单验证插件validate
2016/01/18 Javascript
bootstrap3 兼容IE8浏览器!
2016/05/02 Javascript
js的form表单提交url传参数(包含+等特殊字符)的两种解决方法
2016/05/25 Javascript
JS工作中的小贴士之”闭包“与事件委托的”阻止冒泡“
2016/06/16 Javascript
jquery获取table指定行和列的数据方法(当前选中行、列)
2016/11/07 Javascript
Webpack实现按需打包Lodash的几种方法详解
2017/05/08 Javascript
React 路由懒加载的几种实现方案
2018/10/23 Javascript
微信小程序webview实现长按点击识别二维码功能示例
2019/01/24 Javascript
JavaScript创建、读取和删除cookie
2019/09/03 Javascript
JS实现4位随机验证码
2020/10/19 Javascript
Vue 实现拨打电话操作
2020/11/16 Javascript
[12:21]VICI vs TNC (BO3)
2018/06/07 DOTA
深入理解Python 关于supper 的 用法和原理
2018/02/28 Python
解决Pycharm的项目目录突然消失的问题
2020/01/20 Python
在Python IDLE 下调用anaconda中的库教程
2020/03/09 Python
python3:excel操作之读取数据并返回字典 + 写入的案例
2020/09/01 Python
浅谈利用缓存来优化HTML5 Canvas程序的性能
2015/05/12 HTML / CSS
舒适的豪华鞋:Taryn Rose
2018/05/03 全球购物
美国领先的在线旅游网站:Orbitz
2018/11/05 全球购物
说说在weblogic中开发消息Bean时的persistent与non-persisten的差别
2013/04/07 面试题
区三好学生主要事迹
2014/01/30 职场文书
测绘专业大学生职业生涯规划书
2014/02/10 职场文书
成龙洗发水广告词
2014/03/14 职场文书
2014小学教师年度考核工作总结
2014/12/03 职场文书
销售员自我评价
2015/03/11 职场文书
Redis调用Lua脚本及使用场景快速掌握
2022/03/16 Redis
MySQL数据库Innodb 引擎实现mvcc锁
2022/05/06 MySQL