Tensorflow使用支持向量机拟合线性回归


Posted in Python onSeptember 07, 2018

支持向量机可以用来拟合线性回归。

相同的最大间隔(maximum margin)的概念应用到线性回归拟合。代替最大化分割两类目标是,最大化分割包含大部分的数据点(x,y)。我们将用相同的iris数据集,展示用刚才的概念来进行花萼长度与花瓣宽度之间的线性拟合。

相关的损失函数类似于max(0,|yi-(Axi+b)|-ε)。ε这里,是间隔宽度的一半,这意味着如果一个数据点在该区域,则损失等于0。

# SVM Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve support vector regression. We are going
# to find the line that has the maximum margin
# which INCLUDES as many points as possible
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Pedal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# Split data into train/test sets
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function
# = max(0, abs(target - predicted) + epsilon)
# 1/2 margin width parameter = epsilon
epsilon = tf.constant([0.5])
# Margin term in loss
loss = tf.reduce_mean(tf.maximum(0., tf.subtract(tf.abs(tf.subtract(model_output, y_target)), epsilon)))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.075)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
train_loss = []
test_loss = []
for i in range(200):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = np.transpose([x_vals_train[rand_index]])
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_train_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_train]), y_target: np.transpose([y_vals_train])})
  train_loss.append(temp_train_loss)

  temp_test_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_test]), y_target: np.transpose([y_vals_test])})
  test_loss.append(temp_test_loss)
  if (i+1)%50==0:
    print('-----------')
    print('Generation: ' + str(i+1))
    print('A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Train Loss = ' + str(temp_train_loss))
    print('Test Loss = ' + str(temp_test_loss))

# Extract Coefficients
[[slope]] = sess.run(A)
[[y_intercept]] = sess.run(b)
[width] = sess.run(epsilon)

# Get best fit line
best_fit = []
best_fit_upper = []
best_fit_lower = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)
 best_fit_upper.append(slope*i+y_intercept+width)
 best_fit_lower.append(slope*i+y_intercept-width)

# Plot fit with data
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='SVM Regression Line', linewidth=3)
plt.plot(x_vals, best_fit_upper, 'r--', linewidth=2)
plt.plot(x_vals, best_fit_lower, 'r--', linewidth=2)
plt.ylim([0, 10])
plt.legend(loc='lower right')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(train_loss, 'k-', label='Train Set Loss')
plt.plot(test_loss, 'r--', label='Test Set Loss')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.legend(loc='upper right')
plt.show()

输出结果:

-----------
Generation: 50
A = [[ 2.91328382]] b = [[ 1.18453276]]
Train Loss = 1.17104
Test Loss = 1.1143
-----------
Generation: 100
A = [[ 2.42788291]] b = [[ 2.3755331]]
Train Loss = 0.703519
Test Loss = 0.715295
-----------
Generation: 150
A = [[ 1.84078252]] b = [[ 3.40453291]]
Train Loss = 0.338596
Test Loss = 0.365562
-----------
Generation: 200
A = [[ 1.35343242]] b = [[ 4.14853334]]
Train Loss = 0.125198
Test Loss = 0.16121

 

Tensorflow使用支持向量机拟合线性回归

基于iris数据集(花萼长度和花瓣宽度)的支持向量机回归,间隔宽度为0.5

Tensorflow使用支持向量机拟合线性回归

每次迭代的支持向量机回归的损失值(训练集和测试集)

直观地讲,我们认为SVM回归算法试图把更多的数据点拟合到直线两边2ε宽度的间隔内。这时拟合的直线对于ε参数更有意义。如果选择太小的ε值,SVM回归算法在间隔宽度内不能拟合更多的数据点;如果选择太大的ε值,将有许多条直线能够在间隔宽度内拟合所有的数据点。作者更倾向于选取更小的ε值,因为在间隔宽度附近的数据点比远处的数据点贡献更少的损失。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现搜索本地文件信息写入文件的方法
Feb 22 Python
浅谈Scrapy框架普通反爬虫机制的应对策略
Dec 28 Python
python实现外卖信息管理系统
Jan 11 Python
用Python解数独的方法示例
Oct 24 Python
python单向循环链表原理与实现方法示例
Dec 03 Python
python3中numpy函数tile的用法详解
Dec 04 Python
PYcharm 激活方法(推荐)
Mar 23 Python
记一次django内存异常排查及解决方法
Aug 07 Python
通过代码实例了解Python sys模块
Sep 14 Python
Python json解析库jsonpath原理及使用示例
Nov 25 Python
Python结合百度语音识别实现实时翻译软件的实现
Jan 18 Python
一行Python命令实现批量加水印
Apr 07 Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
You might like
php设计模式 DAO(数据访问对象模式)
2011/06/26 PHP
php基于base64解码图片与加密图片还原实例
2014/11/03 PHP
php中debug_backtrace、debug_print_backtrace和匿名函数用法实例
2014/12/01 PHP
Symfony2函数用法实例分析
2016/03/18 PHP
PHP命名空间与自动加载类详解
2018/09/04 PHP
js操作select控件的几种方法
2010/06/02 Javascript
JQuery 中几个类选择器的简单使用介绍
2013/03/14 Javascript
js特效,页面下雪的小例子
2013/06/17 Javascript
JS代码同步文本框内容的实例方法
2013/07/12 Javascript
jQuery切换网页皮肤并保存到Cookie示例代码
2014/06/16 Javascript
js原生Ajax的封装和原理详解
2017/03/11 Javascript
使用openSpeDiv方法实现Ecshop登录弹窗框效果
2017/03/13 Javascript
jQuery取得元素标签名称小结(附代码)
2017/08/16 jQuery
jquery 键盘事件的使用方法详解
2017/09/13 jQuery
关于vuejs中v-if和v-show的区别及v-show不起作用问题
2018/03/26 Javascript
Three.js中矩阵和向量的使用教程
2019/03/19 Javascript
webpack.DefinePlugin与cross-env区别详解
2020/02/23 Javascript
js+audio实现音乐播放器
2020/09/13 Javascript
[51:14]LGD vs VP 2018国际邀请赛淘汰赛BO3 第一场 8.21
2018/08/22 DOTA
用Python计算三角函数之acos()方法的使用
2015/05/15 Python
Python中取整的几种方法小结
2017/01/06 Python
简单的python后台管理程序
2017/04/13 Python
python程序中的线程操作 concurrent模块使用详解
2019/09/23 Python
python 实现方阵的对角线遍历示例
2019/11/29 Python
python属于解释型语言么
2020/06/15 Python
python opencv图像处理(素描、怀旧、光照、流年、滤镜 原理及实现)
2020/12/10 Python
英国领先的在线鱼贩:The Fish Society
2020/08/12 全球购物
学习教师法的心得体会
2014/09/03 职场文书
2014年重阳节敬老活动方案
2014/09/16 职场文书
2014教师党员自我评议总结
2014/09/19 职场文书
采购部年度工作总结
2015/08/13 职场文书
高三数学教学反思
2016/02/18 职场文书
八年级地理课件资料及考点知识分享
2019/08/30 职场文书
JS如何使用剪贴板操作Clipboard API
2021/05/17 Javascript
Redis调用Lua脚本及使用场景快速掌握
2022/03/16 Redis
Python中使用Opencv开发停车位计数器功能
2022/04/04 Python