Tensorflow使用支持向量机拟合线性回归


Posted in Python onSeptember 07, 2018

支持向量机可以用来拟合线性回归。

相同的最大间隔(maximum margin)的概念应用到线性回归拟合。代替最大化分割两类目标是,最大化分割包含大部分的数据点(x,y)。我们将用相同的iris数据集,展示用刚才的概念来进行花萼长度与花瓣宽度之间的线性拟合。

相关的损失函数类似于max(0,|yi-(Axi+b)|-ε)。ε这里,是间隔宽度的一半,这意味着如果一个数据点在该区域,则损失等于0。

# SVM Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve support vector regression. We are going
# to find the line that has the maximum margin
# which INCLUDES as many points as possible
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Pedal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# Split data into train/test sets
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function
# = max(0, abs(target - predicted) + epsilon)
# 1/2 margin width parameter = epsilon
epsilon = tf.constant([0.5])
# Margin term in loss
loss = tf.reduce_mean(tf.maximum(0., tf.subtract(tf.abs(tf.subtract(model_output, y_target)), epsilon)))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.075)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
train_loss = []
test_loss = []
for i in range(200):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = np.transpose([x_vals_train[rand_index]])
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_train_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_train]), y_target: np.transpose([y_vals_train])})
  train_loss.append(temp_train_loss)

  temp_test_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_test]), y_target: np.transpose([y_vals_test])})
  test_loss.append(temp_test_loss)
  if (i+1)%50==0:
    print('-----------')
    print('Generation: ' + str(i+1))
    print('A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Train Loss = ' + str(temp_train_loss))
    print('Test Loss = ' + str(temp_test_loss))

# Extract Coefficients
[[slope]] = sess.run(A)
[[y_intercept]] = sess.run(b)
[width] = sess.run(epsilon)

# Get best fit line
best_fit = []
best_fit_upper = []
best_fit_lower = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)
 best_fit_upper.append(slope*i+y_intercept+width)
 best_fit_lower.append(slope*i+y_intercept-width)

# Plot fit with data
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='SVM Regression Line', linewidth=3)
plt.plot(x_vals, best_fit_upper, 'r--', linewidth=2)
plt.plot(x_vals, best_fit_lower, 'r--', linewidth=2)
plt.ylim([0, 10])
plt.legend(loc='lower right')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(train_loss, 'k-', label='Train Set Loss')
plt.plot(test_loss, 'r--', label='Test Set Loss')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.legend(loc='upper right')
plt.show()

输出结果:

-----------
Generation: 50
A = [[ 2.91328382]] b = [[ 1.18453276]]
Train Loss = 1.17104
Test Loss = 1.1143
-----------
Generation: 100
A = [[ 2.42788291]] b = [[ 2.3755331]]
Train Loss = 0.703519
Test Loss = 0.715295
-----------
Generation: 150
A = [[ 1.84078252]] b = [[ 3.40453291]]
Train Loss = 0.338596
Test Loss = 0.365562
-----------
Generation: 200
A = [[ 1.35343242]] b = [[ 4.14853334]]
Train Loss = 0.125198
Test Loss = 0.16121

 

Tensorflow使用支持向量机拟合线性回归

基于iris数据集(花萼长度和花瓣宽度)的支持向量机回归,间隔宽度为0.5

Tensorflow使用支持向量机拟合线性回归

每次迭代的支持向量机回归的损失值(训练集和测试集)

直观地讲,我们认为SVM回归算法试图把更多的数据点拟合到直线两边2ε宽度的间隔内。这时拟合的直线对于ε参数更有意义。如果选择太小的ε值,SVM回归算法在间隔宽度内不能拟合更多的数据点;如果选择太大的ε值,将有许多条直线能够在间隔宽度内拟合所有的数据点。作者更倾向于选取更小的ε值,因为在间隔宽度附近的数据点比远处的数据点贡献更少的损失。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Node.js和Socket.IO扩展Django的实时处理功能
Apr 20 Python
Python使用迭代器打印螺旋矩阵的思路及代码示例
Jul 02 Python
python 生成器生成杨辉三角的方法(必看)
Apr 10 Python
Python FTP两个文件夹间的同步实例代码
May 25 Python
Python动态生成多维数组的方法示例
Aug 09 Python
Python3数字求和的实例
Feb 19 Python
PythonWeb项目Django部署在Ubuntu18.04腾讯云主机上
Apr 01 Python
Python TCPServer 多线程多客户端通信的实现
Dec 31 Python
PyTorch 随机数生成占用 CPU 过高的解决方法
Jan 13 Python
详解有关PyCharm安装库失败的问题的解决方法
Feb 02 Python
Python日志:自定义输出字段 json格式输出方式
Apr 27 Python
快速创建python 虚拟环境
Nov 28 Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
You might like
PHP验证终端类型是否为手机的简单实例
2017/02/07 PHP
JavaScript让IE浏览器event对象符合W3C DOM标准
2009/11/24 Javascript
jquery异步循环获取功能实现代码
2010/09/19 Javascript
解决jQuery插件tipswindown与hintbox冲突
2010/11/05 Javascript
jquery获取复选框被选中的值
2014/03/22 Javascript
jQuery中serializeArray()与serialize()的区别实例分析
2015/12/09 Javascript
复杂的javascript窗口分帧解析
2016/02/19 Javascript
完美解决jQuery 鼠标快速滑过后,会执行多次滑出的问题
2016/12/08 Javascript
教大家轻松制作Bootstrap漂亮表格(table)
2016/12/13 Javascript
原生js实现新闻列表展开/收起全文功能
2017/01/20 Javascript
Node.JS文件系统解析实例详解
2017/05/15 Javascript
html中通过JS获取JSON数据并加载的方法
2017/11/30 Javascript
详解ES6语法之可迭代协议和迭代器协议
2018/01/13 Javascript
jQuery实现的淡入淡出图片轮播效果示例
2018/08/29 jQuery
layui 弹出删除确认界面的实例
2019/09/06 Javascript
JavaScript算法学习之冒泡排序和选择排序
2019/11/02 Javascript
Windows下用py2exe将Python程序打包成exe程序的教程
2015/04/08 Python
python中学习K-Means和图片压缩
2017/11/20 Python
详解在python操作数据库中游标的使用方法
2019/11/12 Python
selenium+python实现自动登陆QQ邮箱并发送邮件功能
2019/12/13 Python
解决Keras中Embedding层masking与Concatenate层不可调和的问题
2020/06/18 Python
Python优秀开源项目Rich源码解析的流程分析
2020/07/06 Python
Python txt文件常用读写操作代码实例
2020/08/03 Python
会计辞职信范文
2014/01/15 职场文书
四川成都导游欢迎词
2014/01/18 职场文书
运动会广播稿400字
2014/01/25 职场文书
大专毕业自我鉴定
2014/02/04 职场文书
就业自我评价
2014/02/04 职场文书
优秀毕业生就业推荐信
2014/05/22 职场文书
2014年无财产无子女离婚协议书范本
2014/10/09 职场文书
红色故事汇观后感
2015/06/18 职场文书
七年级作文(600字3篇)
2019/09/24 职场文书
Ajax是什么?Ajax高级用法之Axios技术
2021/04/21 Javascript
详解overflow:hidden的作用(溢出隐藏、清除浮动、解决外边距塌陷)
2021/07/01 HTML / CSS
css3中2D转换之有趣的transform形变效果
2022/02/24 HTML / CSS
避坑之 JavaScript 中的toFixed()和正则表达式
2022/04/19 Javascript