TensorFlow如何实现反向传播


Posted in Python onFebruary 06, 2018

使用TensorFlow的一个优势是,它可以维护操作状态和基于反向传播自动地更新模型变量。
TensorFlow通过计算图来更新变量和最小化损失函数来反向传播误差的。这步将通过声明优化函数(optimization function)来实现。一旦声明好优化函数,TensorFlow将通过它在所有的计算图中解决反向传播的项。当我们传入数据,最小化损失函数,TensorFlow会在计算图中根据状态相应的调节变量。

回归算法的例子从均值为1、标准差为0.1的正态分布中抽样随机数,然后乘以变量A,损失函数为L2正则损失函数。理论上,A的最优值是10,因为生成的样例数据均值是1。

二个例子是一个简单的二值分类算法。从两个正态分布(N(-1,1)和N(3,1))生成100个数。所有从正态分布N(-1,1)生成的数据标为目标类0;从正态分布N(3,1)生成的数据标为目标类1,模型算法通过sigmoid函数将这些生成的数据转换成目标类数据。换句话讲,模型算法是sigmoid(x+A),其中,A是要拟合的变量,理论上A=-1。假设,两个正态分布的均值分别是m1和m2,则达到A的取值时,它们通过-(m1+m2)/2转换成到0等距的值。后面将会在TensorFlow中见证怎样取到相应的值。

同时,指定一个合适的学习率对机器学习算法的收敛是有帮助的。优化器类型也需要指定,前面的两个例子会使用标准梯度下降法,它在TensorFlow中的实现是GradientDescentOptimizer()函数。

# 反向传播
#----------------------------------
#
# 以下Python函数主要是展示回归和分类模型的反向传播

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 创建计算图会话
sess = tf.Session()

# 回归算法的例子:
# We will create sample data as follows:
# x-data: 100 random samples from a normal ~ N(1, 0.1)
# target: 100 values of the value 10.
# We will fit the model:
# x-data * A = target
# Theoretically, A = 10.

# 生成数据,创建占位符和变量A
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# Create variable (one model parameter = A)
A = tf.Variable(tf.random_normal(shape=[1]))

# 增加乘法操作
my_output = tf.multiply(x_data, A)

# 增加L2正则损失函数
loss = tf.square(my_output - y_target)

# 在运行优化器之前,需要初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 声明变量的优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 训练算法
for i in range(100):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})))

# 分类算法例子
# We will create sample data as follows:
# x-data: sample 50 random values from a normal = N(-1, 1)
#     + sample 50 random values from a normal = N(1, 1)
# target: 50 values of 0 + 50 values of 1.
#     These are essentially 100 values of the corresponding output index
# We will fit the binary classification model:
# If sigmoid(x+A) < 0.5 -> 0 else 1
# Theoretically, A should be -(mean1 + mean2)/2

# 重置计算图
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# 生成数据
x_vals = np.concatenate((np.random.normal(-1, 1, 50), np.random.normal(3, 1, 50)))
y_vals = np.concatenate((np.repeat(0., 50), np.repeat(1., 50)))
x_data = tf.placeholder(shape=[1], dtype=tf.float32)
y_target = tf.placeholder(shape=[1], dtype=tf.float32)

# 偏差变量A (one model parameter = A)
A = tf.Variable(tf.random_normal(mean=10, shape=[1]))

# 增加转换操作
# Want to create the operstion sigmoid(x + A)
# Note, the sigmoid() part is in the loss function
my_output = tf.add(x_data, A)

# 由于指定的损失函数期望批量数据增加一个批量数的维度
# 这里使用expand_dims()函数增加维度
my_output_expanded = tf.expand_dims(my_output, 0)
y_target_expanded = tf.expand_dims(y_target, 0)

# 初始化变量A
init = tf.global_variables_initializer()
sess.run(init)

# 声明损失函数 交叉熵(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output_expanded, labels=y_target_expanded)

# 增加一个优化器函数 让TensorFlow知道如何更新和偏差变量
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 迭代
for i in range(1400):
  rand_index = np.random.choice(100)
  rand_x = [x_vals[rand_index]]
  rand_y = [y_vals[rand_index]]

  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)))
    print('Loss = ' + str(sess.run(xentropy, feed_dict={x_data: rand_x, y_target: rand_y})))

# 评估预测
predictions = []
for i in range(len(x_vals)):
  x_val = [x_vals[i]]
  prediction = sess.run(tf.round(tf.sigmoid(my_output)), feed_dict={x_data: x_val})
  predictions.append(prediction[0])

accuracy = sum(x==y for x,y in zip(predictions, y_vals))/100.
print('最终精确度 = ' + str(np.round(accuracy, 2)))

输出:

Step #25 A = [ 6.12853956]
Loss = [ 16.45088196]
Step #50 A = [ 8.55680943]
Loss = [ 2.18415046]
Step #75 A = [ 9.50547695]
Loss = [ 5.29813051]
Step #100 A = [ 9.89214897]
Loss = [ 0.34628963]
Step #200 A = [ 3.84576249]
Loss = [[ 0.00083012]]
Step #400 A = [ 0.42345378]
Loss = [[ 0.01165466]]
Step #600 A = [-0.35141727]
Loss = [[ 0.05375391]]
Step #800 A = [-0.74206048]
Loss = [[ 0.05468176]]
Step #1000 A = [-0.89036471]
Loss = [[ 0.19636908]]
Step #1200 A = [-0.90850282]
Loss = [[ 0.00608062]]
Step #1400 A = [-1.09374011]
Loss = [[ 0.11037558]]
最终精确度 = 1.0

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之函数对象(函数也是对象)
Aug 30 Python
python中ConfigParse模块的用法
Sep 29 Python
Python中列表的一些基本操作知识汇总
May 20 Python
python从网络读取图片并直接进行处理的方法
May 22 Python
python获得文件创建时间和修改时间的方法
Jun 30 Python
python文件名和文件路径操作实例
Sep 29 Python
Python 实现选择排序的算法步骤
Apr 22 Python
详解从Django Rest Framework响应中删除空字段
Jan 11 Python
pyqt5 实现多窗口跳转的方法
Jun 19 Python
python使用原始套接字发送二层包(链路层帧)的方法
Jul 22 Python
使用OpenCV-python3实现滑动条更新图像的Canny边缘检测功能
Dec 12 Python
Python拼接字符串的7种方式详解
Mar 19 Python
tensorflow TFRecords文件的生成和读取的方法
Feb 06 #Python
TensorFlow实现创建分类器
Feb 06 #Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
You might like
ThinkPHP 防止表单重复提交的方法
2011/08/08 PHP
Zend Framework 2.0事件管理器(The EventManager)入门教程
2014/08/11 PHP
php实现约瑟夫问题的方法小结
2015/03/23 PHP
php实现encode64编码类实例
2015/03/24 PHP
实例解析php的数据类型
2018/10/24 PHP
Laravel框架处理用户的请求操作详解
2019/12/20 PHP
PHP过滤器 filter_has_var() 函数用法实例分析
2020/04/23 PHP
JS array 数组详解
2009/03/22 Javascript
解析Jquery的LigerUI如何实现文件上传
2013/07/09 Javascript
display和visibility的区别示例介绍
2014/02/26 Javascript
js给网页加上背景音乐及选择音效的方法
2015/03/03 Javascript
详解JavaScript正则表达式之RegExp对象
2015/12/13 Javascript
Node.js DES加密的简单实现
2016/07/07 Javascript
微信小程序开发之圆形菜单 仿建行圆形菜单实例
2016/12/12 Javascript
jQuery Ajax自定义分页组件(jquery.loehpagerv1.0)实例详解
2017/05/01 jQuery
js自定义瀑布流布局插件
2017/05/16 Javascript
基于canvas粒子系统的构建详解
2017/08/31 Javascript
使用Fullpage插件快速开发整屏翻页的页面
2017/09/13 Javascript
微信小程序实现图片翻转效果的实例代码
2019/09/20 Javascript
在Vue中获取自定义属性方法:data-id的实例
2020/09/09 Javascript
vue+springboot+element+vue-resource实现文件上传教程
2020/10/21 Javascript
[00:12]DAC2018 Miracle-站上中单舞台,他能否再写奇迹?
2018/04/06 DOTA
[42:52]Optic vs Serenity 2018国际邀请赛淘汰赛BO3 第二场 8.22
2018/08/23 DOTA
[00:56]PWL开团时刻DAY8——追追追追追!
2020/11/09 DOTA
在Django框架中伪造捕捉到的URLconf值的方法
2015/07/18 Python
浅谈python正则的常用方法 覆盖范围70%以上
2018/03/14 Python
python将txt文档每行内容循环插入数据库的方法
2018/12/28 Python
PyQt5基本控件使用详解:单选按钮、复选框、下拉框
2019/08/05 Python
python常用数据重复项处理方法
2019/11/22 Python
Tensorflow 多线程设置方式
2020/02/06 Python
Django关于admin的使用技巧和知识点
2020/02/10 Python
keras的siamese(孪生网络)实现案例
2020/06/12 Python
会计学自我鉴定
2014/02/06 职场文书
军训自我鉴定100字
2014/02/13 职场文书
2015年暑期社会实践总结
2015/07/13 职场文书
管理者们如何制定2019年的工作计划?
2019/07/01 职场文书