TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
把MySQL表结构映射为Python中的对象的教程
Apr 07 Python
python通过zabbix api获取主机
Sep 17 Python
一百多行python代码实现抢票助手
Sep 25 Python
浅谈解除装饰器作用(python3新增)
Oct 15 Python
Python线程池模块ThreadPoolExecutor用法分析
Dec 28 Python
scrapy-redis的安装部署步骤讲解
Feb 27 Python
详解Python 切片语法
Jun 10 Python
python logging添加filter教程
Dec 24 Python
python数据库编程 ODBC方式实现通讯录
Mar 27 Python
Django添加bootstrap框架时无法加载静态文件的解决方式
Mar 27 Python
使用openCV去除文字中乱入的线条实例
Jun 02 Python
Python Pandas读取Excel日期数据的异常处理方法
Feb 28 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
PHP 程序员也要学会使用“异常”
2009/06/16 PHP
PHP获取http请求的头信息实现步骤
2012/12/16 PHP
php读取目录及子目录下所有文件名的方法
2014/10/20 PHP
Linux系统下PHP-FPM的安装和配置教程
2015/08/17 PHP
phpmailer简单发送邮件的方法(附phpmailer源码下载)
2016/06/13 PHP
php 可变函数使用小结
2018/06/12 PHP
js中的cookie的读写操作示例详解
2014/04/17 Javascript
JS实现漂亮的淡蓝色滑动门效果代码
2015/09/23 Javascript
Javascript刷新窗口方法小结
2015/10/21 Javascript
jQuery实现的给图片点赞+1动画效果(附在线演示及demo源码下载)
2015/12/31 Javascript
Angular2表单自定义验证器的实现
2016/10/19 Javascript
简单实现js轮播图效果
2017/07/14 Javascript
vue实现PC端录音功能的实例代码
2019/06/05 Javascript
Nodejs监控事件循环异常示例详解
2019/09/22 NodeJs
微信小程序网络请求实现过程解析
2019/11/06 Javascript
JS实现放大镜效果
2020/09/21 Javascript
Python中设置变量作为默认值时容易遇到的错误
2015/04/03 Python
python函数式编程学习之yield表达式形式详解
2018/03/25 Python
python 编码规范整理
2018/05/05 Python
pycharm使用matplotlib.pyplot不显示图形的解决方法
2018/10/28 Python
对Python协程之异步同步的区别详解
2019/02/19 Python
关于keras.layers.Conv1D的kernel_size参数使用介绍
2020/05/22 Python
Python 使用双重循环打印图形菱形操作
2020/08/09 Python
Python直接赋值及深浅拷贝原理详解
2020/09/05 Python
基于Python爬取搜狐证券股票过程解析
2020/11/18 Python
用python制作个音乐下载器
2021/01/30 Python
Html5游戏开发之乒乓Ping Pong游戏示例(二)
2013/01/21 HTML / CSS
惠普美国官方商店:HP Official Store
2016/08/28 全球购物
英国Amara家居法国网站:家居装饰,现代装饰和豪华礼品
2016/12/15 全球购物
C#的几个面试问题
2016/05/22 面试题
医学生实习自我鉴定
2013/09/27 职场文书
合作意向协议书
2015/01/29 职场文书
2019年公司卫生管理制度样本
2019/08/21 职场文书
浅谈Redis中的RDB快照
2021/06/29 Redis
部分武汉产收音机展览
2022/04/07 无线电
vue项目proxyTable配置和部署服务器
2022/04/14 Vue.js