TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
全面解读Python Web开发框架Django
Jun 30 Python
Python实现的二维码生成小软件
Jul 11 Python
python:socket传输大文件示例
Jan 18 Python
网站渗透常用Python小脚本查询同ip网站
May 08 Python
修改python plot折线图的坐标轴刻度方法
Dec 13 Python
python3.6数独问题的解决
Jan 21 Python
Python 在OpenCV里实现仿射变换—坐标变换效果
Aug 30 Python
基于Tensorflow使用CPU而不用GPU问题的解决
Feb 07 Python
Pandas —— resample()重采样和asfreq()频度转换方式
Feb 26 Python
python在CMD界面读取excel所有数据的示例
Sep 28 Python
python 发送get请求接口详解
Nov 17 Python
Python Http请求json解析库用法解析
Nov 28 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
PHP扩展编写点滴 技巧收集
2010/03/09 PHP
php excel类 phpExcel使用方法介绍
2010/08/21 PHP
PHP中include与require使用方法区别详解
2013/10/19 PHP
Yii框架常见缓存应用实例小结
2019/09/09 PHP
Array.slice()与Array.splice()的返回值类型
2006/10/09 Javascript
JavaScript作用域链示例分享
2014/05/27 Javascript
网页实时显示服务器时间和javscript自运行时钟
2014/06/09 Javascript
一个JavaScript处理textarea中的字符成每一行实例
2014/09/22 Javascript
angularjs表格ng-table使用备忘录
2016/03/09 Javascript
基于javascript显示当前时间以及倒计时功能
2016/03/18 Javascript
Bootstrap CDN和本地化环境搭建
2016/10/26 Javascript
MUI 上拉刷新/下拉加载功能实例代码
2017/04/13 Javascript
easyUI下拉列表点击事件使用方法
2017/05/18 Javascript
Jquery EasyUI $.Parser
2017/06/02 jQuery
Vue中的ref作用详解(实现DOM的联动操作)
2017/08/21 Javascript
vuex的使用及持久化state的方式详解
2018/01/23 Javascript
Vue子组件向父组件通信与父组件调用子组件中的方法
2018/06/22 Javascript
详解vue-cli下ESlint 配置说明
2018/09/03 Javascript
python使用WMI检测windows系统信息、硬盘信息、网卡信息的方法
2015/05/15 Python
Python中模块string.py详解
2017/03/12 Python
Python交互环境下实现输入代码
2018/06/22 Python
解决pycharm每次新建项目都要重新安装一些第三方库的问题
2019/01/17 Python
浅谈python3.6的tkinter运行问题
2019/02/22 Python
详解Django-restframework 之频率源码分析
2019/02/27 Python
python 使用socket传输图片视频等文件的实现方式
2019/08/07 Python
Python 开发工具PyCharm安装教程图文详解(新手必看)
2020/02/28 Python
办公自动化毕业生求职信
2014/03/09 职场文书
揭牌仪式策划方案
2014/05/28 职场文书
应届生自荐书
2014/06/23 职场文书
党员一帮一活动总结
2014/07/08 职场文书
新课培训心得体会
2014/09/03 职场文书
某集团股份有限公司委托书样本
2014/09/24 职场文书
镇人大副主席民主生活会对照检查材料思想汇报
2014/10/01 职场文书
2014年小学数学工作总结
2014/12/12 职场文书
微信告警的zabbix监控系统 监控整个NGINX集群
2022/04/18 Servers
Apache SeaTunnel实现 非CDC数据抽取
2022/05/20 Servers