用TensorFlow实现戴明回归算法的示例


Posted in Python onMay 02, 2018

如果最小二乘线性回归算法最小化到回归直线的竖直距离(即,平行于y轴方向),则戴明回归最小化到回归直线的总距离(即,垂直于回归直线)。其最小化x值和y值两个方向的误差,具体的对比图如下图。

用TensorFlow实现戴明回归算法的示例 

线性回归算法和戴明回归算法的区别。左边的线性回归最小化到回归直线的竖直距离;右边的戴明回归最小化到回归直线的总距离。

线性回归算法的损失函数最小化竖直距离;而这里需要最小化总距离。给定直线的斜率和截距,则求解一个点到直线的垂直距离有已知的几何公式。代入几何公式并使TensorFlow最小化距离。

损失函数是由分子和分母组成的几何公式。给定直线y=mx+b,点(x0,y0),则求两者间的距离的公式为:

用TensorFlow实现戴明回归算法的示例

# 戴明回归
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear Deming regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare Demming loss function
demming_numerator = tf.abs(tf.subtract(y_target, tf.add(tf.matmul(x_data, A), b)))
demming_denominator = tf.sqrt(tf.add(tf.square(A),1))
loss = tf.reduce_mean(tf.truediv(demming_numerator, demming_denominator))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.1)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(250):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%50==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# Plot the result
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

用TensorFlow实现戴明回归算法的示例 

用TensorFlow实现戴明回归算法的示例 

本文的戴明回归算法与线性回归算法得到的结果基本一致。两者之间的关键不同点在于预测值与数据点间的损失函数度量:线性回归算法的损失函数是竖直距离损失;而戴明回归算法是垂直距离损失(到x轴和y轴的总距离损失)。

注意,这里戴明回归算法的实现类型是总体回归(总的最小二乘法误差)。总体回归算法是假设x值和y值的误差是相似的。我们也可以根据不同的理念使用不同的误差来扩展x轴和y轴的距离计算。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用SQLite的简单教程
Apr 29 Python
Python基于smtplib实现异步发送邮件服务
May 28 Python
举例讲解Python的lambda语句声明匿名函数的用法
Jul 01 Python
python 打印对象的所有属性值的方法
Sep 11 Python
python中matplotlib的颜色及线条控制的示例
Mar 16 Python
python 实现UTC时间加减的方法
Dec 31 Python
postman传递当前时间戳实例详解
Sep 14 Python
Python 模拟动态产生字母验证码图片功能
Dec 24 Python
在Django中预防CSRF攻击的操作
Mar 13 Python
计算Python Numpy向量之间的欧氏距离实例
May 22 Python
Python连接mysql方法及常用参数
Sep 01 Python
python 中的@运算符使用
May 26 Python
用TensorFlow实现lasso回归和岭回归算法的示例
May 02 #Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
You might like
全国FM电台频率大全 - 25 云南省
2020/03/11 无线电
10条PHP高级技巧[修正版]
2011/08/02 PHP
深入浅析PHP无限极分类的案例教程
2016/05/09 PHP
php 访问oracle 存储过程实例详解
2017/01/08 PHP
PHP长网址与短网址的实现方法
2017/10/13 PHP
TP5(thinkPHP5)框架使用ajax实现与后台数据交互的方法小结
2020/02/10 PHP
JavaScript国旗变换效果代码
2008/08/13 Javascript
window.ActiveXObject使用说明
2010/11/08 Javascript
使用jquery操作session方法分享
2015/01/22 Javascript
JavaScript使用replace函数替换字符串的方法
2015/04/06 Javascript
详解参数传递四种形式
2015/07/21 Javascript
JS实现左右拖动改变内容显示区域大小的方法
2015/10/13 Javascript
jquery二级目录选中当前页的css样式
2016/12/08 Javascript
移动端脚本框架Hammer.js
2016/12/15 Javascript
学习 NodeJS 第八天:Socket 通讯实例
2016/12/21 NodeJs
js控制一个按钮是否可点击(可使用)disabled的实例
2017/02/14 Javascript
Thinkphp5微信小程序获取用户信息接口的实例详解
2017/09/26 Javascript
JavaScript中var、let、const区别浅析
2018/06/24 Javascript
Vue实现textarea固定输入行数与添加下划线样式的思路详解
2018/06/28 Javascript
vue+iview+less 实现换肤功能
2018/08/17 Javascript
Vuejs监听vuex中值的变化的方法示例
2018/12/02 Javascript
js实现unicode码字符串与utf8字节数据互转详解
2019/03/21 Javascript
详解ES6 Promise的生命周期和创建
2019/08/18 Javascript
Javascript实现html转pdf高清版(提高分辨率)
2020/02/19 Javascript
python爬虫常用的模块分析
2014/08/29 Python
如何使用python爬取csdn博客访问量
2016/02/14 Python
Python进程间通信Queue实例解析
2018/01/25 Python
PyCharm中代码字体大小调整方法
2019/07/29 Python
利用html5 canvas动态画饼状图的示例代码
2018/04/02 HTML / CSS
新大陆软件面试题
2016/11/24 面试题
会计应聘求职信范文
2013/12/17 职场文书
库房主管岗位职责
2013/12/31 职场文书
商铺租赁意向书
2014/04/01 职场文书
golang中的空接口使用详解
2021/03/30 Python
记一次Mysql不走日期字段索引的原因小结
2021/10/24 MySQL
Java 通过手写分布式雪花SnowFlake生成ID方法详解
2022/04/07 Java/Android