TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中easy_install 和 pip 的安装及使用
Jun 05 Python
pytorch + visdom 处理简单分类问题的示例
Jun 04 Python
对Python 窗体(tkinter)文本编辑器(Text)详解
Oct 11 Python
Python+PyQT5的子线程更新UI界面的实例
Jun 14 Python
Django实现文件上传下载
Oct 06 Python
python使用paramiko实现ssh的功能详解
Mar 06 Python
Python matplotlib绘制图形实例(包括点,曲线,注释和箭头)
Apr 17 Python
python list等分并从等分的子集中随机选取一个数
Nov 16 Python
使用Python爬虫爬取小红书完完整整的全过程
Jan 19 Python
Python页面加载的等待方式总结
Feb 28 Python
OpenCV-Python直方图均衡化实现图像去雾
Jun 07 Python
Python常用配置文件ini、json、yaml读写总结
Jul 09 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
Base64在线编码解码实现代码 演示与下载
2011/01/08 PHP
php 购物车完整实现代码
2014/06/05 PHP
PHP中ini_set与ini_get用法实例
2014/11/04 PHP
php中try catch捕获异常实例详解
2014/11/21 PHP
PHP中SESSION的注销与清除
2015/04/16 PHP
PHP版本的选择5.2.17 5.3.27 5.3.28 5.4 5.5兼容性问题分析
2016/04/04 PHP
javascript arguments 传递给函数的隐含参数
2009/08/21 Javascript
学习并汇集javascript匿名函数
2010/11/25 Javascript
javascript学习笔记(四) Number 数字类型
2012/06/19 Javascript
jquery中获取id值方法小结
2013/09/22 Javascript
表单序列化与jq中的serialize使用示例
2014/02/21 Javascript
5个数组Array方法: indexOf、filter、forEach、map、reduce使用实例
2015/01/29 Javascript
js实现圆盘记速表
2015/08/03 Javascript
jquery实现的Banner广告收缩效果代码
2015/09/02 Javascript
jQuery中选择器的基础使用教程
2016/05/23 Javascript
详解Sea.js中Module.exports和exports的区别
2017/02/12 Javascript
Angular实现下载安装包的功能代码分享
2017/09/05 Javascript
基于 Vue 的 Electron 项目搭建过程图文详解
2020/07/22 Javascript
Antd-vue Table组件添加Click事件,实现点击某行数据教程
2020/11/17 Javascript
[01:02:54]完美世界DOTA2联赛PWL S2 FTD vs GXR 第一场 11.22
2020/11/26 DOTA
python的re模块应用实例
2014/09/26 Python
python中enumerate函数用法实例分析
2015/05/20 Python
Python数组定义方法
2016/04/13 Python
python解析json串与正则匹配对比方法
2018/12/20 Python
python生成每日报表数据(Excel)并邮件发送的实例
2019/02/03 Python
python3中datetime库,time库以及pandas中的时间函数区别与详解
2020/04/16 Python
python实现扫雷游戏的示例
2020/10/20 Python
HTML5地理定位实例
2014/10/15 HTML / CSS
戴尔美国官网:Dell
2016/08/31 全球购物
印度第一网上礼品店:IGP.com
2020/02/06 全球购物
本科生求职简历的自我评价
2013/10/21 职场文书
土木工程毕业生自荐信
2013/11/12 职场文书
小学生开学第一课活动方案
2014/03/27 职场文书
护理专业自荐书
2014/06/04 职场文书
2014年内勤工作总结
2014/11/24 职场文书
2015国庆节放假通知范文
2015/07/30 职场文书