TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python全局变量用法实例分析
Jul 19 Python
利用Python爬取微博数据生成词云图片实例代码
Aug 31 Python
python3安装pip3(install pip3 for python 3.x)
Apr 03 Python
Python实现字典(dict)的迭代操作示例
Jun 05 Python
Python查看微信撤回消息代码
Jun 07 Python
Python实现的简单计算器功能详解
Aug 25 Python
Python实现的逻辑回归算法示例【附测试csv文件下载】
Dec 28 Python
python用类实现文章敏感词的过滤方法示例
Oct 27 Python
Python Numpy中数据的常用保存与读取方法
Apr 01 Python
使用pycharm和pylint检查python代码规范操作
Jun 09 Python
python实现梯度下降算法的实例详解
Aug 17 Python
python 两种方法修改文件的创建时间、修改时间、访问时间
Sep 26 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
php 大数据量及海量数据处理算法总结
2011/05/07 PHP
php开发过程中关于继承的使用方法分享
2011/06/17 PHP
php 获取百度的热词数据的代码
2012/02/18 PHP
PHP解密Unicode及Escape加密字符串
2015/05/17 PHP
[原创]php正则删除img标签的方法示例
2017/05/27 PHP
PHP实现创建微信自定义菜单的方法示例
2017/07/14 PHP
Yii框架模拟组件调用注入示例
2019/11/11 PHP
javascript radio 联动效果
2009/03/04 Javascript
Mootools 1.2教程 滚动条(Slider)
2009/09/15 Javascript
为jquery.ui.dialog 增加“在当前鼠标位置打开”的功能
2009/11/24 Javascript
JQuery团队打造的javascript单元测试工具QUnit介绍
2010/02/26 Javascript
js最简单的拖拽效果实现代码
2010/09/24 Javascript
JavaScrip调试技巧之断点调试
2015/10/22 Javascript
微信小程序如何像vue一样在动态绑定类名
2018/04/17 Javascript
angularjs的单选框+ng-repeat的实现方法
2018/09/12 Javascript
Jquery高级应用Deferred对象原理及使用实例
2020/05/28 jQuery
js实现数字跳动到指定数字
2020/08/25 Javascript
[02:35]DOTA2超级联赛专访XB 难忘一年九冠称王
2013/06/20 DOTA
[01:02:45]完美世界DOTA2联赛 LBZS vs Forest 第三场 11.07
2020/11/09 DOTA
python编码最佳实践之总结
2016/02/14 Python
浅析python的Lambda表达式
2019/02/27 Python
Python3 chardet模块查看编码格式的例子
2019/08/14 Python
Python嵌套函数,作用域与偏函数用法实例分析
2019/12/26 Python
益模软件Java笔试题
2012/03/27 面试题
最新个人职业生涯规划书
2014/01/22 职场文书
副校长竞聘演讲稿
2014/09/01 职场文书
党的群众路线教育实践活动党员个人剖析材料
2014/10/08 职场文书
乡镇群众路线专项整治方案
2014/11/03 职场文书
网络营销计划书
2015/01/17 职场文书
公司禁烟通知
2015/04/23 职场文书
人生遥控器观后感
2015/06/11 职场文书
新年寄语2016
2015/08/17 职场文书
升学宴学生致辞
2015/09/29 职场文书
前端监听websocket消息并实时弹出(实例代码)
2021/11/27 Javascript
Redis超详细讲解高可用主从复制基础与哨兵模式方案
2022/04/07 Redis
MyBatis XPathParser解析器使用范例详解
2022/07/15 Java/Android