TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的简单文件传输服务器和客户端
Apr 08 Python
python实现提取百度搜索结果的方法
May 19 Python
Python带动态参数功能的sqlite工具类
May 26 Python
tensorflow: 查看 tensor详细数值方法
Jun 13 Python
python的concat等多种用法详解
Nov 28 Python
pycharm 实现显示project 选项卡的方法
Jan 17 Python
Pyqt5 实现跳转界面并关闭当前界面的方法
Jun 19 Python
keras模型可视化,层可视化及kernel可视化实例
Jan 24 Python
基于Python3.6中的OpenCV实现图片色彩空间的转换
Feb 03 Python
如何通过Python3和ssl实现加密通信功能
May 09 Python
Python实现石头剪刀布游戏
Jan 20 Python
Python使用sql语句对mysql数据库多条件模糊查询的思路详解
Apr 12 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
php公用函数列表[正则]
2007/02/22 PHP
Fine Uploader文件上传组件应用介绍
2013/01/06 PHP
Destoon旺旺无法正常显示,点击提示“会员名不存在”的解决办法
2014/06/21 PHP
php简单统计在线人数的方法
2016/05/10 PHP
PHP实现类似于C语言的文件读取及解析功能
2017/09/01 PHP
基于swoole实现多人聊天室
2018/06/14 PHP
php文件上传原理与实现方法详解
2019/12/20 PHP
Javascript Request获取请求参数如何实现
2012/11/28 Javascript
浅谈Javascript中的12种DOM节点类型
2016/08/19 Javascript
js放大镜放大购物图片效果
2017/01/18 Javascript
bootstrapValidator 重新启用提交按钮的方法
2017/02/20 Javascript
详解node nvm进行node多版本管理
2017/10/21 Javascript
浅谈vue中使用图片懒加载vue-lazyload插件详细指南
2017/10/23 Javascript
layui select动态添加option的实例
2018/03/07 Javascript
vue如何安装使用Quill富文本编辑器
2018/09/21 Javascript
Vue基本使用之对象提供的属性功能
2019/04/30 Javascript
JavaScript实现京东放大镜效果
2019/12/03 Javascript
使用preload预加载页面资源时注意事项
2020/02/03 Javascript
Python中decorator使用实例
2015/04/14 Python
Python多进程分块读取超大文件的方法
2016/04/13 Python
python实现redis三种cas事务操作
2017/12/19 Python
基于numpy中数组元素的切片复制方法
2018/11/15 Python
python小程序基于Jupyter实现天气查询的方法
2020/03/27 Python
keras和tensorflow使用fit_generator 批次训练操作
2020/07/03 Python
canvas实现漂亮的下雨效果的示例
2018/04/18 HTML / CSS
英国乐购杂货:Tesco Groceries
2018/11/29 全球购物
亿企通软件测试面试题
2012/04/10 面试题
食品安全检查制度
2014/02/03 职场文书
《蜗牛》教学反思
2014/02/18 职场文书
财产公证书
2014/04/10 职场文书
竞选班长演讲稿500字
2014/08/22 职场文书
五五普法心得体会
2014/09/04 职场文书
会计工作自我鉴定范文
2019/06/21 职场文书
Python实现随机生成迷宫并自动寻路
2021/06/13 Python
Python接口自动化之文件上传/下载接口详解
2022/04/05 Python
python微信智能AI机器人实现多种支付方式
2022/04/12 Python