TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python转码问题的解决方法
Oct 07 Python
详解Python验证码识别
Jan 25 Python
Python爬虫设置代理IP的方法(爬虫技巧)
Mar 04 Python
Python多重继承的方法解析执行顺序实例分析
May 26 Python
python 查找文件名包含指定字符串的方法
Jun 05 Python
python中ASCII码和字符的转换方法
Jul 09 Python
python文件操作之批量修改文件后缀名的方法
Aug 10 Python
Pytorch在NLP中的简单应用详解
Jan 08 Python
Pytorch 保存模型生成图片方式
Jan 10 Python
Python request操作步骤及代码实例
Apr 13 Python
python绘图pyecharts+pandas的使用详解
Dec 13 Python
Python爬取网站图片并保存的实现示例
Feb 26 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
windows下开发并编译PHP扩展的方法
2011/03/18 PHP
php Smarty初体验二 获取配置信息
2011/08/08 PHP
php启用zlib压缩文件的配置方法
2013/06/12 PHP
php读取二进制流(C语言结构体struct数据文件)的深入解析
2013/06/13 PHP
php下获取http状态的实现代码
2014/05/09 PHP
Zend Framework入门应用实例详解
2016/12/11 PHP
yii使用bootstrap分页样式的实例
2017/01/17 PHP
推荐dojo学习笔记
2007/03/24 Javascript
一个JavaScript去除字符串末尾的空白实例代码
2014/09/22 Javascript
使用js画图之画切线
2015/01/12 Javascript
JavaScript实现控制打开文件另存为对话框的方法
2015/04/17 Javascript
js正则表达式replace替换变量方法
2016/05/21 Javascript
jQuery购物网页经典制作案例
2016/08/19 Javascript
BootstrapValidator不触发校验的实现代码
2016/09/28 Javascript
echart简介_动力节点Java学院整理
2017/08/11 Javascript
javascript中undefined的本质解析
2019/07/31 Javascript
Win10下python 2.7.13 安装配置方法图文教程
2018/09/18 Python
python 获取sqlite3数据库的表名和表字段名的实例
2019/07/17 Python
django的聚合函数和aggregate、annotate方法使用详解
2019/07/23 Python
python处理document文档保留原样式
2019/09/23 Python
Python 通过爬虫实现GitHub网页的模拟登录的示例代码
2020/08/17 Python
基于Python组装jmx并调用JMeter实现压力测试
2020/11/03 Python
python 批量下载bilibili视频的gui程序
2020/11/20 Python
python 用pandas实现数据透视表功能
2020/12/21 Python
css3 background属性调整增强介绍
2010/12/18 HTML / CSS
CSS3图片旋转特效(360/60/-360度)
2013/10/10 HTML / CSS
Skyscanner英国:苏格兰的全球三大领先航班搜索服务之一
2017/11/09 全球购物
泰国排名第一的家居用品中心:HomePro
2020/11/18 全球购物
程序员机试试题汇总
2012/03/07 面试题
使用C#编写创建一个线程的代码
2013/01/22 面试题
网站客服岗位职责
2014/04/05 职场文书
医疗器械售后服务承诺书
2014/05/21 职场文书
好的促销活动方案
2014/08/21 职场文书
高三毕业评语
2014/12/31 职场文书
黄埔军校观后感
2015/06/10 职场文书
导游词之崇武古城
2019/10/07 职场文书