TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
分析Python编程时利用wxPython来支持多线程的方法
Apr 07 Python
详解使用 pyenv 管理多个版本 python 环境
Oct 19 Python
python解析html提取数据,并生成word文档实例解析
Jan 22 Python
详解Django中类视图使用装饰器的方式
Aug 12 Python
python数据结构之线性表的顺序存储结构
Sep 28 Python
Python异常处理知识点总结
Feb 18 Python
Python之time模块的时间戳,时间字符串格式化与转换方法(13位时间戳)
Aug 12 Python
关于PyTorch源码解读之torchvision.models
Aug 17 Python
postman和python mock测试过程图解
Feb 22 Python
python 如何停止一个死循环的线程
Nov 24 Python
Django url 路由匹配过程详解
Jan 22 Python
python数据抓取3种方法总结
Feb 07 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
浅谈PHP语法(1)
2006/10/09 PHP
PHP Imagick完美实现图片裁切、生成缩略图、添加水印
2016/02/22 PHP
PHP中spl_autoload_register()函数用法实例详解
2016/07/18 PHP
php断点续传之文件分割合并详解
2016/12/13 PHP
Thinkphp事务操作实例(推荐)
2017/04/01 PHP
jQuery中的ready函数与window.onload谁先执行
2016/06/21 Javascript
js自调用匿名函数的三种写法(推荐)
2016/08/19 Javascript
整理关于Bootstrap表单的慕课笔记
2017/03/29 Javascript
jQuery 改变P标签文本值方法
2018/02/24 jQuery
Vue刷新修改页面中数据的方法
2018/09/16 Javascript
JS函数动态传递参数的方法分析【基于arguments对象】
2019/06/05 Javascript
layer.open的自适应及居中及子页面标题的修改方法
2019/09/05 Javascript
JS forEach跳出循环2种实现方法
2020/06/24 Javascript
JavaScript ES 模块的使用
2020/11/12 Javascript
[48:20]OpTic vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
实例探究Python以并发方式编写高性能端口扫描器的方法
2016/06/14 Python
解决django接口无法通过ip进行访问的问题
2020/03/27 Python
python实现从ftp上下载文件的实例方法
2020/07/19 Python
python 制作网站小说下载器
2021/02/20 Python
HTML5+CSS3:3D展示商品信息示例
2017/01/03 HTML / CSS
使用phonegap查找联系人的实现方法
2017/03/31 HTML / CSS
哄娃神器4moms商店:美国婴童用品品牌
2019/03/07 全球购物
DogBuddy荷兰:找到你最完美的狗保姆
2019/04/17 全球购物
中学生获奖感言
2014/02/04 职场文书
幼儿园教师教学反思
2014/02/06 职场文书
银行贷款承诺书
2014/03/29 职场文书
购房意向书
2014/04/01 职场文书
小学生期末评语
2014/04/21 职场文书
企业员工爱岗敬业演讲稿
2014/08/26 职场文书
新闻发布会活动策划方案
2014/09/15 职场文书
委托代理人授权委托书范本
2014/09/24 职场文书
战马观后感
2015/06/08 职场文书
2016年教师新年寄语
2015/08/18 职场文书
2016入党积极分子党校培训心得体会
2016/01/06 职场文书
python树莓派通过队列实现进程交互的程序分析
2021/07/04 Python
Netty分布式客户端处理接入事件handle源码解析
2022/03/25 Java/Android