TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python tempfile模块学习笔记(临时文件)
May 25 Python
python黑魔法之参数传递
Feb 12 Python
Python找出list中最常出现元素的方法
Jun 14 Python
Django使用httpresponse返回用户头像实例代码
Jan 26 Python
python实现屏保计时器的示例代码
Aug 08 Python
python selenium执行所有测试用例并生成报告的方法
Feb 13 Python
Python2和3字符编码的区别知识点整理
Aug 08 Python
Django admin.py 在修改/添加表单界面显示额外字段的方法
Aug 22 Python
python实现计算器功能
Oct 31 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 Python
Python3 集合set入门基础
Feb 10 Python
Python 输出详细的异常信息(traceback)方式
Apr 08 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
简单的php写入数据库类代码分享
2011/07/26 PHP
Laravel 5 框架入门(三)
2015/04/09 PHP
AJAX PHP无刷新form表单提交的简单实现(推荐)
2016/09/09 PHP
PHP中SERIALIZE和JSON的序列化与反序列化操作区别分析
2016/10/11 PHP
PHP实现的Redis多库选择功能单例类
2017/07/27 PHP
利用PHP计算有多少小于当前数字的数字方法示例
2020/08/26 PHP
禁止刷新,回退的JS
2006/11/25 Javascript
基于jquery的横向滚动条(滑动条)
2011/02/24 Javascript
JavaScript判断一个URL链接是否有效的实现方法
2011/10/08 Javascript
indexOf 和 lastIndexOf 使用示例介绍
2014/09/02 Javascript
2014年50个程序员最适用的免费JQuery插件
2014/12/15 Javascript
举例讲解AngularJS中的模块
2015/06/17 Javascript
js实现浏览本地文件并显示扩展名的方法
2015/08/17 Javascript
分享使用AngularJS创建应用的5个框架
2015/12/05 Javascript
Bootstrap轮播加上css3动画,炫酷到底!
2015/12/22 Javascript
Ionic快速安装教程
2016/06/03 Javascript
js 性能优化之算法和流程控制
2017/02/15 Javascript
详解使用React全家桶搭建一个后台管理系统
2017/11/04 Javascript
JavaScript时间日期操作实例小结【5个示例】
2018/12/22 Javascript
大转盘抽奖小程序版 转盘抽奖网页版
2020/04/16 Javascript
Vue路由守卫之路由独享守卫
2019/09/25 Javascript
Node绑定全局TraceID的实现方法
2019/11/14 Javascript
JS 图片压缩原理与实现方法详解
2020/04/29 Javascript
vue监听浏览器原生返回按钮,进行路由转跳操作
2020/09/09 Javascript
Python的Django框架可适配的各种数据库介绍
2015/07/15 Python
Python实现求笛卡尔乘积的方法
2017/09/16 Python
Python脚本按照当前日期创建多级目录
2019/03/01 Python
pytorch 在sequential中使用view来reshape的例子
2019/08/20 Python
python pip安装包出现:Failed building wheel for xxx错误的解决
2019/12/25 Python
解决Python Matplotlib绘图数据点位置错乱问题
2020/05/16 Python
CSS3悬停效果案例应用
2012/11/21 HTML / CSS
西班牙电子产品购物网站:Electronicamente
2018/07/26 全球购物
红色连衣裙精品店:Red Dress Boutique
2018/08/11 全球购物
高中生学习生活的自我评价
2013/11/27 职场文书
教师党员思想汇报
2014/01/06 职场文书
颐和园导游词
2015/01/30 职场文书