TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 判断是否为正小数和正整数的实例
Jul 23 Python
Python探索之静态方法和类方法的区别详解
Oct 27 Python
python爬虫爬取网页表格数据
Mar 07 Python
详解Pytorch 使用Pytorch拟合多项式(多项式回归)
May 24 Python
基于腾讯云服务器部署微信小程序后台服务(Python+Django)
May 08 Python
Django框架 Pagination分页实现代码实例
Sep 04 Python
wxPython实现带颜色的进度条
Nov 19 Python
python3 webp转gif格式的实现示例
Dec 10 Python
python自动下载图片的方法示例
Mar 25 Python
使用Django xadmin 实现修改时间选择器为不可输入状态
Mar 30 Python
详解Python 函数参数的拆解
Sep 02 Python
Django migrate报错的解决方案
May 20 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
初次接触php抽象工厂模式(Elgg)
2010/03/21 PHP
PHP实现简单爬虫的方法
2015/07/29 PHP
php实现word转html的方法
2016/01/22 PHP
Mac下php 5升级到php 7的步骤详解
2017/04/26 PHP
js 对象是否存在判断
2009/07/15 Javascript
javascript 自动填写表单的实现方法
2010/04/09 Javascript
jqGrid增加时--判断开始日期与结束日期(实例解析)
2013/11/08 Javascript
javascript在子页面中函数无法调试问题解决方法
2014/01/17 Javascript
javascript面向对象快速入门实例
2015/01/13 Javascript
JavaScript实现点击文字切换登录窗口的方法
2015/05/11 Javascript
jquery实现鼠标滑过小图查看大图的方法
2015/07/20 Javascript
jQuery代码性能优化的10种方法
2016/06/21 Javascript
jQuery ajax方法传递中文时出现中文乱码的解决方法
2016/07/25 Javascript
JavaScript实战(原生range和自定义特效)简单实例
2016/08/21 Javascript
详解JavaScript权威指南之对象
2016/09/27 Javascript
jQuery日期范围选择器附源码下载
2017/05/23 jQuery
React-Native实现ListView组件之上拉刷新实例(iOS和Android通用)
2017/07/11 Javascript
使用JavaScript实现链表的数据结构的代码
2017/08/02 Javascript
JS实现点击下拉列表文本框中出现对应的网址,点击跳转按钮实现跳转
2019/11/25 Javascript
python实现得到一个给定类的虚函数
2014/09/28 Python
Python中断言Assertion的一些改进方案
2016/10/27 Python
python切片及sys.argv[]用法详解
2018/05/25 Python
Python3.4学习笔记之 idle 清屏扩展插件用法分析
2019/03/01 Python
ipython和python区别详解
2019/06/26 Python
TensorFlow——Checkpoint为模型添加检查点的实例
2020/01/21 Python
CSS3实现各种图形的示例代码
2016/10/19 HTML / CSS
测绘工程专业个人自我评价
2013/12/01 职场文书
无故旷工检讨书
2014/01/26 职场文书
企业安全生产责任书
2014/04/14 职场文书
师德师风个人反思
2014/04/28 职场文书
计算机多媒体专业自荐信
2014/07/04 职场文书
酒店爱岗敬业演讲稿
2014/09/02 职场文书
在职证明书范本(2014新版)
2014/09/25 职场文书
婚前协议书标准版
2014/10/19 职场文书
教师继续教育反思周记
2015/06/25 职场文书
Python+pyaudio实现音频控制示例详解
2022/07/23 Python