TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Web服务器Tornado使用小结
May 06 Python
Python标准库内置函数complex介绍
Nov 25 Python
Python压缩和解压缩zip文件
Feb 14 Python
python Flask实现restful api service
Dec 04 Python
对Python中list的倒序索引和切片实例讲解
Nov 15 Python
Python设计模式之建造者模式实例详解
Jan 17 Python
在PyCharm中控制台输出日志分层级分颜色显示的方法
Jul 11 Python
Python 项目转化为so文件实例
Dec 23 Python
Python urlopen()和urlretrieve()用法解析
Jan 07 Python
python 实现在shell窗口中编写print不向屏幕输出
Feb 19 Python
python 伯努利分布详解
Feb 25 Python
Python3.9.1中使用match方法详解
Feb 08 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
php注入实例
2006/10/09 PHP
以文本方式上传二进制文件的PHP程序
2006/10/09 PHP
如何在PHP中使用Oracle数据库(4)
2006/10/09 PHP
PHP中spl_autoload_register函数的用法总结
2013/11/07 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(九)
2014/06/24 PHP
使用php自动备份数据库表的实现方法
2017/07/28 PHP
一个简单的js鼠标划过切换效果
2010/06/30 Javascript
Extjs中TabPane如何嵌套在其他网页中实现思路及代码
2013/01/27 Javascript
NODE.JS加密模块CRYPTO常用方法介绍
2014/06/05 Javascript
jQuery中prop()方法用法实例
2015/01/05 Javascript
javascript实现图像循环明暗变化的方法
2015/02/25 Javascript
JavaScript实现广告的关闭与显示效果实例
2015/07/02 Javascript
JavaScript实现点击按钮直接打印
2016/01/06 Javascript
浅析JS操作DOM的一些常用方法
2016/05/13 Javascript
利用vue实现模态框组件
2016/12/19 Javascript
浅谈关于axios和session的一些事
2017/07/13 Javascript
微信小程序使用setData修改数组中单个对象的方法分析
2018/12/30 Javascript
JavaScript中.min.js和.js文件的区别讲解
2019/02/13 Javascript
[03:24]CDEC.Y赛前采访 努力备战2016国际邀请赛中国区预选赛
2016/06/25 DOTA
[00:12]DAC SOLO赛卫冕冠军 VG.Paparazi灬展现SOLO技巧
2018/04/06 DOTA
[48:28]完美世界DOTA2联赛循环赛FTD vs Magma第二场 10月30日
2020/10/31 DOTA
对python csv模块配置分隔符和引用符详解
2018/12/12 Python
PyQt5实现暗黑风格的计时器
2019/07/29 Python
pytorch实现特殊的Module--Sqeuential三种写法
2020/01/15 Python
Python操作Jira库常用方法解析
2020/04/10 Python
全球知名旅游社区巴西站点:TripAdvisor巴西
2016/07/21 全球购物
Joe Fresh官网:加拿大时尚品牌和零售连锁店
2016/11/30 全球购物
西班牙著名的珠宝首饰品牌:P D PAOLA
2018/09/15 全球购物
数据库设计的包括哪两种,请分别进行说明
2016/07/15 面试题
某公司.Net方向面试题
2014/04/24 面试题
大学生开西餐厅创业计划书
2014/02/01 职场文书
核心价值观演讲稿
2014/05/13 职场文书
2014年保洁员工作总结
2014/11/19 职场文书
2014年大班保育员工作总结
2014/12/02 职场文书
2015年为民办实事工作总结
2015/05/26 职场文书
2016幼儿园教师节新闻稿
2015/11/25 职场文书