Tensorflow使用支持向量机拟合线性回归


Posted in Python onSeptember 07, 2018

支持向量机可以用来拟合线性回归。

相同的最大间隔(maximum margin)的概念应用到线性回归拟合。代替最大化分割两类目标是,最大化分割包含大部分的数据点(x,y)。我们将用相同的iris数据集,展示用刚才的概念来进行花萼长度与花瓣宽度之间的线性拟合。

相关的损失函数类似于max(0,|yi-(Axi+b)|-ε)。ε这里,是间隔宽度的一半,这意味着如果一个数据点在该区域,则损失等于0。

# SVM Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve support vector regression. We are going
# to find the line that has the maximum margin
# which INCLUDES as many points as possible
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Pedal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# Split data into train/test sets
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function
# = max(0, abs(target - predicted) + epsilon)
# 1/2 margin width parameter = epsilon
epsilon = tf.constant([0.5])
# Margin term in loss
loss = tf.reduce_mean(tf.maximum(0., tf.subtract(tf.abs(tf.subtract(model_output, y_target)), epsilon)))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.075)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
train_loss = []
test_loss = []
for i in range(200):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = np.transpose([x_vals_train[rand_index]])
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_train_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_train]), y_target: np.transpose([y_vals_train])})
  train_loss.append(temp_train_loss)

  temp_test_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_test]), y_target: np.transpose([y_vals_test])})
  test_loss.append(temp_test_loss)
  if (i+1)%50==0:
    print('-----------')
    print('Generation: ' + str(i+1))
    print('A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Train Loss = ' + str(temp_train_loss))
    print('Test Loss = ' + str(temp_test_loss))

# Extract Coefficients
[[slope]] = sess.run(A)
[[y_intercept]] = sess.run(b)
[width] = sess.run(epsilon)

# Get best fit line
best_fit = []
best_fit_upper = []
best_fit_lower = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)
 best_fit_upper.append(slope*i+y_intercept+width)
 best_fit_lower.append(slope*i+y_intercept-width)

# Plot fit with data
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='SVM Regression Line', linewidth=3)
plt.plot(x_vals, best_fit_upper, 'r--', linewidth=2)
plt.plot(x_vals, best_fit_lower, 'r--', linewidth=2)
plt.ylim([0, 10])
plt.legend(loc='lower right')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(train_loss, 'k-', label='Train Set Loss')
plt.plot(test_loss, 'r--', label='Test Set Loss')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.legend(loc='upper right')
plt.show()

输出结果:

-----------
Generation: 50
A = [[ 2.91328382]] b = [[ 1.18453276]]
Train Loss = 1.17104
Test Loss = 1.1143
-----------
Generation: 100
A = [[ 2.42788291]] b = [[ 2.3755331]]
Train Loss = 0.703519
Test Loss = 0.715295
-----------
Generation: 150
A = [[ 1.84078252]] b = [[ 3.40453291]]
Train Loss = 0.338596
Test Loss = 0.365562
-----------
Generation: 200
A = [[ 1.35343242]] b = [[ 4.14853334]]
Train Loss = 0.125198
Test Loss = 0.16121

 

Tensorflow使用支持向量机拟合线性回归

基于iris数据集(花萼长度和花瓣宽度)的支持向量机回归,间隔宽度为0.5

Tensorflow使用支持向量机拟合线性回归

每次迭代的支持向量机回归的损失值(训练集和测试集)

直观地讲,我们认为SVM回归算法试图把更多的数据点拟合到直线两边2ε宽度的间隔内。这时拟合的直线对于ε参数更有意义。如果选择太小的ε值,SVM回归算法在间隔宽度内不能拟合更多的数据点;如果选择太大的ε值,将有许多条直线能够在间隔宽度内拟合所有的数据点。作者更倾向于选取更小的ε值,因为在间隔宽度附近的数据点比远处的数据点贡献更少的损失。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3使用urllib模块制作网络爬虫
Apr 08 Python
Python爬取京东的商品分类与链接
Aug 26 Python
Python 爬虫学习笔记之多线程爬虫
Sep 21 Python
Python3 循环语句(for、while、break、range等)
Nov 20 Python
django用户注册、登录、注销和用户扩展的示例
Mar 19 Python
python方法生成txt标签文件的实例代码
May 10 Python
python抽取指定url页面的title方法
May 11 Python
Python使用sort和class实现的多级排序功能示例
Aug 15 Python
python实现飞机大战微信小游戏
Mar 21 Python
python使用if语句实现一个猜拳游戏详解
Aug 27 Python
opencv python图像梯度实例详解
Feb 04 Python
python实现自动清理重复文件
Aug 24 Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
You might like
全国FM电台频率大全 - 1 北京市
2020/03/11 无线电
php 日期和时间的处理-郑阿奇(续)
2011/07/04 PHP
对象失去焦点时自己动提交数据的实现代码
2012/11/06 PHP
8个必备的PHP功能实例代码
2013/10/27 PHP
php实现redis数据库指定库号迁移的方法
2015/01/14 PHP
Zend Framework动作助手Redirector用法实例详解
2016/03/05 PHP
Yii2增加验证码步骤详解
2016/04/25 PHP
PHP使用PDO访问oracle数据库的步骤详解
2017/09/29 PHP
javascript语句中的CDATA标签的意义
2007/05/09 Javascript
一个选择最快的服务器转向代码
2009/04/27 Javascript
15 个 JavaScript Web UI 库
2010/05/19 Javascript
ExtJs 表单提交登陆实现代码
2010/08/19 Javascript
JavaScript面向对象(极简主义法minimalist approach)
2012/07/17 Javascript
js计算字符串长度包含的中文是utf8格式
2013/10/15 Javascript
Jquery 数组操作大全个人总结
2013/11/13 Javascript
javascript中数组的多种定义方法和常用函数简介
2014/05/09 Javascript
jQuery+CSS3实现四种应用广泛的导航条制作实例详解
2016/09/17 Javascript
十大热门的JavaScript框架和库
2017/03/21 Javascript
Vue-router路由判断页面未登录跳转到登录页面的实例
2017/10/26 Javascript
Jquery实现无缝向上循环滚动列表的特效
2019/02/13 jQuery
[01:15:44]首部DOTA2纪录片今日23时全网上映
2014/03/19 DOTA
[36:02]DOTA2上海特级锦标赛D组小组赛#2 Liquid VS VP第一局
2016/02/28 DOTA
深入理解Python中的内置常量
2017/05/20 Python
解决Pycharm中import时无法识别自己写的程序方法
2018/05/18 Python
Python字符串、整数、和浮点型数相互转换实例
2018/08/04 Python
详解Django中类视图使用装饰器的方式
2018/08/12 Python
Django接收自定义http header过程详解
2019/08/23 Python
基于python实现微信好友数据分析(简单)
2020/02/16 Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
2020/05/23 Python
Keras保存模型并载入模型继续训练的实现
2021/02/20 Python
Python使用paramiko连接远程服务器执行Shell命令的实现
2021/03/04 Python
CSS3的文字阴影—text-shadow的使用方法
2012/12/25 HTML / CSS
团日活动总结
2014/04/28 职场文书
2016廉政教育学习心得体会
2016/01/25 职场文书
浅谈Golang 切片(slice)扩容机制的原理
2021/06/09 Golang
golang中字符串MD5生成方式总结
2021/07/04 Golang