Tensorflow使用支持向量机拟合线性回归


Posted in Python onSeptember 07, 2018

支持向量机可以用来拟合线性回归。

相同的最大间隔(maximum margin)的概念应用到线性回归拟合。代替最大化分割两类目标是,最大化分割包含大部分的数据点(x,y)。我们将用相同的iris数据集,展示用刚才的概念来进行花萼长度与花瓣宽度之间的线性拟合。

相关的损失函数类似于max(0,|yi-(Axi+b)|-ε)。ε这里,是间隔宽度的一半,这意味着如果一个数据点在该区域,则损失等于0。

# SVM Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve support vector regression. We are going
# to find the line that has the maximum margin
# which INCLUDES as many points as possible
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Pedal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# Split data into train/test sets
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function
# = max(0, abs(target - predicted) + epsilon)
# 1/2 margin width parameter = epsilon
epsilon = tf.constant([0.5])
# Margin term in loss
loss = tf.reduce_mean(tf.maximum(0., tf.subtract(tf.abs(tf.subtract(model_output, y_target)), epsilon)))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.075)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
train_loss = []
test_loss = []
for i in range(200):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = np.transpose([x_vals_train[rand_index]])
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_train_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_train]), y_target: np.transpose([y_vals_train])})
  train_loss.append(temp_train_loss)

  temp_test_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_test]), y_target: np.transpose([y_vals_test])})
  test_loss.append(temp_test_loss)
  if (i+1)%50==0:
    print('-----------')
    print('Generation: ' + str(i+1))
    print('A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Train Loss = ' + str(temp_train_loss))
    print('Test Loss = ' + str(temp_test_loss))

# Extract Coefficients
[[slope]] = sess.run(A)
[[y_intercept]] = sess.run(b)
[width] = sess.run(epsilon)

# Get best fit line
best_fit = []
best_fit_upper = []
best_fit_lower = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)
 best_fit_upper.append(slope*i+y_intercept+width)
 best_fit_lower.append(slope*i+y_intercept-width)

# Plot fit with data
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='SVM Regression Line', linewidth=3)
plt.plot(x_vals, best_fit_upper, 'r--', linewidth=2)
plt.plot(x_vals, best_fit_lower, 'r--', linewidth=2)
plt.ylim([0, 10])
plt.legend(loc='lower right')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(train_loss, 'k-', label='Train Set Loss')
plt.plot(test_loss, 'r--', label='Test Set Loss')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.legend(loc='upper right')
plt.show()

输出结果:

-----------
Generation: 50
A = [[ 2.91328382]] b = [[ 1.18453276]]
Train Loss = 1.17104
Test Loss = 1.1143
-----------
Generation: 100
A = [[ 2.42788291]] b = [[ 2.3755331]]
Train Loss = 0.703519
Test Loss = 0.715295
-----------
Generation: 150
A = [[ 1.84078252]] b = [[ 3.40453291]]
Train Loss = 0.338596
Test Loss = 0.365562
-----------
Generation: 200
A = [[ 1.35343242]] b = [[ 4.14853334]]
Train Loss = 0.125198
Test Loss = 0.16121

 

Tensorflow使用支持向量机拟合线性回归

基于iris数据集(花萼长度和花瓣宽度)的支持向量机回归,间隔宽度为0.5

Tensorflow使用支持向量机拟合线性回归

每次迭代的支持向量机回归的损失值(训练集和测试集)

直观地讲,我们认为SVM回归算法试图把更多的数据点拟合到直线两边2ε宽度的间隔内。这时拟合的直线对于ε参数更有意义。如果选择太小的ε值,SVM回归算法在间隔宽度内不能拟合更多的数据点;如果选择太大的ε值,将有许多条直线能够在间隔宽度内拟合所有的数据点。作者更倾向于选取更小的ε值,因为在间隔宽度附近的数据点比远处的数据点贡献更少的损失。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的Flask框架下收发电子邮件的教程
Apr 21 Python
简单讲解Python中的字符串与字符串的输入输出
Mar 13 Python
Python字符串处理实例详解
May 18 Python
python使用zip将list转为json的方法
Dec 31 Python
Python列表(list)所有元素的同一操作解析
Aug 01 Python
django 自定义filter 判断if var in list的例子
Aug 20 Python
Pytorch 多维数组运算过程的索引处理方式
Dec 27 Python
python如何遍历指定路径下所有文件(按按照时间区间检索)
Sep 14 Python
Python如何急速下载第三方库详解
Nov 02 Python
Python调用飞书发送消息的示例
Nov 10 Python
python实现ping命令小程序
Dec 28 Python
利用python实时刷新基金估值(摸鱼小工具)
Sep 15 Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
You might like
使用PHP和XSL stylesheets转换XML文档
2006/10/09 PHP
杏林同学录(五)
2006/10/09 PHP
使用Apache的htaccess防止图片被盗链的解决方法
2013/04/27 PHP
使用dump函数,给php加断点测试
2013/06/25 PHP
php字符串操作常见问题小结
2016/10/11 PHP
laravel框架与其他框架的详细对比
2019/10/23 PHP
理解JavaScript中的事件
2006/09/23 Javascript
用JS实现一个TreeMenu效果分享
2011/08/28 Javascript
jQuery 顶部导航跟随滚动条滚动固定浮动在顶部
2014/06/06 Javascript
JavaScript获取各大浏览器信息图示
2015/11/20 Javascript
第一章之初识Bootstrap
2016/04/25 Javascript
javascript事件的绑定基础实例讲解(34)
2017/02/14 Javascript
详细分析单线程JS执行问题
2017/11/22 Javascript
JavaScript函数节流和函数去抖知识点学习
2018/07/31 Javascript
Windows下Node爬虫神器Puppeteer安装记
2019/01/09 Javascript
详解vue.js移动端配置flexible.js及注意事项
2019/04/10 Javascript
使用layer弹窗提交表单时判断表单是否输入为空的例子
2019/09/26 Javascript
详解Vue中的MVVM原理和实现方法
2020/07/15 Javascript
Vue $emit()不能触发父组件方法的原因及解决
2020/07/28 Javascript
[02:32]“虐狗”镜头慎点 2016国际邀请赛中国区预选赛现场玩家采访
2016/06/28 DOTA
python使用cookie库操保存cookie详解
2014/03/03 Python
python append、extend与insert的区别
2016/10/13 Python
python中WSGI是什么,Python应用WSGI详解
2017/11/24 Python
基于MSELoss()与CrossEntropyLoss()的区别详解
2020/01/02 Python
手把手教你进行Python虚拟环境配置教程
2020/02/03 Python
Python字典添加,删除,查询等相关操作方法详解
2020/02/07 Python
Python猫眼电影最近上映的电影票房信息
2020/09/18 Python
Python安装Bs4的多种方法
2020/11/28 Python
CSS3 制作旋转的大风车(充满童年回忆)
2013/01/30 HTML / CSS
机电专业大学生求职信
2013/10/04 职场文书
网络技术专业求职信
2014/07/13 职场文书
护士先进个人总结
2015/02/13 职场文书
开天辟地观后感
2015/06/09 职场文书
小学教师教学反思
2016/02/24 职场文书
励志语录:只有自己足够强大,才能不被别人践踏
2020/01/09 职场文书
MySQL COUNT函数的使用与优化
2021/05/10 MySQL