TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作MySQL数据库9个实用实例
Dec 11 Python
使用Python脚本实现批量网站存活检测遇到问题及解决方法
Oct 11 Python
Python编程判断一个正整数是否为素数的方法
Apr 14 Python
Python下实现的RSA加密/解密及签名/验证功能示例
Jul 17 Python
Python3实现的画图及加载图片动画效果示例
Jan 19 Python
python实现判断一个字符串是否是合法IP地址的示例
Jun 04 Python
Python3安装Pillow与PIL的方法
Apr 03 Python
python里 super类的工作原理详解
Jun 19 Python
python日期相关操作实例小结
Jun 24 Python
详解Python time库的使用
Oct 10 Python
python使用turtle库绘制奥运五环
Feb 24 Python
Python如何将装饰器定义为类
Jul 30 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
来自PHP.NET的入门教程
2006/10/09 PHP
php学习笔记 面向对象的构造与析构方法
2011/06/13 PHP
php防注入,表单提交值转义的实现详解
2013/06/10 PHP
ThinkPHP实现多数据库连接的解决方法
2014/07/01 PHP
Gambit vs ForZe BO3 第二场 2.13
2021/03/10 DOTA
jQuery的实现原理的模拟代码 -3 事件处理
2010/08/03 Javascript
Javascript生成json的函数代码(可以用php的json_decode解码)
2012/06/11 Javascript
jQuery横向擦除焦点图特效代码分享
2015/09/06 Javascript
AngularJS入门教程引导程序
2016/08/18 Javascript
Websocket协议详解及简单实例代码
2016/12/12 Javascript
BootStrap 页签切换失效的解决方法
2017/08/17 Javascript
select自定义小三角样式代码(实用总结)
2017/08/18 Javascript
Vue+Flask实现简单的登录验证跳转的示例代码
2018/01/13 Javascript
vue中如何实现pdf文件预览的方法
2018/07/12 Javascript
Vue项目安装插件并保存
2019/01/28 Javascript
vue webpack build资源相对路径的问题及解决方法
2020/06/04 Javascript
如何在vue中使用kindeditor富文本编辑器
2020/12/19 Vue.js
[02:09]2018DOTA2亚洲邀请赛TNC赛前采访
2018/04/04 DOTA
[51:53]完美世界DOTA2联赛循环赛 LBZS vs DM BO2第二场 11.01
2020/11/02 DOTA
详谈Python2.6和Python3.0中对除法操作的异同
2017/04/28 Python
python numpy 部分排序 寻找最大的前几个数的方法
2018/06/27 Python
python语音识别实践之百度语音API
2018/08/30 Python
在Python中通过threshold创建mask方式
2020/02/19 Python
tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
2020/04/22 Python
python实现录音功能(可随时停止录音)
2020/10/26 Python
椰子猫砂:CatSpot
2018/08/27 全球购物
英国最大的在线床超市:Bed Star
2019/01/24 全球购物
自我评价范文点评
2013/12/04 职场文书
公司端午节活动方案
2014/02/04 职场文书
结婚周年感言
2014/02/24 职场文书
2014年六一儿童节演讲稿
2014/05/23 职场文书
法学专业大学生实习自我鉴定
2014/10/05 职场文书
汽车转让协议书范本
2014/12/07 职场文书
财务个人年度总结范文
2015/02/26 职场文书
css3中2D转换之有趣的transform形变效果
2022/02/24 HTML / CSS
如何Tomcat中使用ipv6地址
2022/05/06 Servers