TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中os操作文件及文件路径实例汇总
Jan 15 Python
Python3读取UTF-8文件及统计文件行数的方法
May 22 Python
Django小白教程之Django用户注册与登录
Apr 22 Python
利用Python画ROC曲线和AUC值计算
Sep 19 Python
Python与Java间Socket通信实例代码
Mar 06 Python
解决Pycharm界面的子窗口不见了的问题
Jan 17 Python
Django REST framework 视图和路由详解
Jul 19 Python
Python pip配置国内源的方法
Feb 14 Python
Pandas读取csv时如何设置列名
Jun 02 Python
Python面向对象多态实现原理及代码实例
Sep 16 Python
python 使用OpenCV进行简单的人像分割与合成
Feb 02 Python
教你用Python+selenium搭建自动化测试环境
Jun 18 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
header()函数使用说明
2006/11/23 PHP
缓存技术详谈―php
2006/12/14 PHP
Zend Studio (eclipse)使用速度优化方法
2011/03/23 PHP
基于PHP实现短信验证码接口(容联运通讯)
2016/09/06 PHP
PHP+JS实现的商品秒杀倒计时用法示例
2016/11/15 PHP
laravel学习笔记之模型事件的几种用法示例
2017/08/15 PHP
Smarty模板类内部原理实例分析
2019/07/03 PHP
thinkphp5.1框架实现格式化mysql时间戳为日期的方式小结
2019/10/10 PHP
表单的一些基本用法与技巧
2006/07/15 Javascript
非常不错的一个javascript 类
2006/11/07 Javascript
Pro JavaScript Techniques学习笔记
2010/12/28 Javascript
JQuery动态创建DOM、表单元素的实现代码
2011/08/09 Javascript
Js+Jq获取URL参数的集中方法示例代码
2014/05/20 Javascript
谈谈jQuery Ajax用法详解
2015/11/27 Javascript
vue.js 1.x与2.0中js实时监听input值的变化
2017/03/15 Javascript
AjaxFileUpload.js实现异步上传文件功能
2019/04/19 Javascript
centos系统升级python 2.7.3
2014/07/03 Python
Python程序员开发中常犯的10个错误
2014/07/07 Python
python实现批量视频分帧、保存视频帧
2019/05/31 Python
Python字典对象实现原理详解
2019/07/01 Python
详解CSS中iconfont的使用
2015/08/04 HTML / CSS
css3 clip实现圆环进度条的示例代码
2018/02/07 HTML / CSS
梅西百货澳大利亚:Macy’s Australia
2017/07/26 全球购物
欧洲最大的滑雪假期供应商之一:Sunweb Holidays
2018/01/06 全球购物
英国领先的野生鸟类食品供应商:GardenBird
2018/08/09 全球购物
如何查询Oracle数据库中已经创建的索引
2013/10/11 面试题
公司出纳岗位职责
2013/12/07 职场文书
党员领导干部承诺书
2014/05/28 职场文书
技术股东合作协议书
2014/12/02 职场文书
青岛导游词
2015/02/12 职场文书
统计工作个人总结
2015/03/03 职场文书
2015年酒店服务员工作总结
2015/05/18 职场文书
农村结婚典礼主持词
2015/06/29 职场文书
幼儿教师师德培训心得体会
2016/01/09 职场文书
2016年六一儿童节开幕词
2016/03/04 职场文书
Python实现滑雪小游戏
2021/09/25 Python