TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python3中使用asyncio库进行快速数据抓取的教程
Apr 02 Python
Python常用算法学习基础教程
Apr 13 Python
Python 快速实现CLI 应用程序的脚手架
Dec 05 Python
numpy数组拼接简单示例
Dec 15 Python
Python发送邮件功能示例【使用QQ邮箱】
Dec 04 Python
python和c语言的主要区别总结
Jul 07 Python
Python BeautifulSoup [解决方法] TypeError: list indices must be integers or slices, not str
Aug 07 Python
详解mac python+selenium+Chrome 简单案例
Nov 08 Python
Python3搭建http服务器的实现代码
Feb 11 Python
Python3使用腾讯云文字识别(腾讯OCR)提取图片中的文字内容实例详解
Feb 18 Python
PyTorch中torch.tensor与torch.Tensor的区别详解
May 18 Python
利用OpenCV中对图像数据进行64F和8U转换的方式
Jun 03 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
php 中include()与require()的对比
2006/10/09 PHP
PHP基于文件存储实现缓存的方法
2015/07/20 PHP
脚本吧 - 幻宇工作室用到js,超强推荐base.js
2006/12/23 Javascript
基于JQuery的简单实现折叠菜单代码
2010/09/15 Javascript
浅析javascript中function 的 length 属性
2014/05/27 Javascript
JavaScript中操作Mysql数据库实例
2015/04/02 Javascript
JS实时弹出新消息提示框并有提示音响起的实现代码
2016/04/20 Javascript
基于MVC+EasyUI的web开发框架之使用云打印控件C-Lodop打印页面或套打报关运单信息
2016/08/29 Javascript
JS监控关闭浏览器操作的实例详解
2017/09/12 Javascript
AngularJS与后端php的数据交互方法
2018/08/13 Javascript
vue项目移动端实现ip输入框问题
2019/03/19 Javascript
JS实现图片轮播效果实例详解【可自动和手动】
2019/04/04 Javascript
基于原生JS封装的Modal对话框插件的示例代码
2020/09/09 Javascript
微信小程序实现翻牌抽奖动画
2020/09/21 Javascript
[46:32]Fnatic vs OG 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
python中的reduce内建函数使用方法指南
2014/08/31 Python
python里对list中的整数求平均并排序
2014/09/12 Python
通过5个知识点轻松搞定Python的作用域
2016/09/09 Python
浅谈Python实现贪心算法与活动安排问题
2017/12/19 Python
基于python的多进程共享变量正确打开方式
2018/04/28 Python
快速解决pandas.read_csv()乱码的问题
2018/06/15 Python
python实现爬取图书封面
2018/07/05 Python
python使用epoll实现服务端的方法
2018/10/16 Python
Python Numpy 实现交换两行和两列的方法
2019/06/26 Python
pandas对dataFrame中某一个列的数据进行处理的方法
2019/07/08 Python
keras tensorflow 实现在python下多进程运行
2020/02/06 Python
css3边框_动力节点Java学院整理
2017/07/11 HTML / CSS
CSS3实现网站商品展示效果图
2020/01/18 HTML / CSS
英国哈罗德园艺:Harrod Horticultural
2020/03/31 全球购物
毕业生在校学习的自我评价分享
2013/10/08 职场文书
毕业生多媒体设计求职信
2013/10/12 职场文书
新学期红领巾广播稿
2014/01/14 职场文书
教师党员公开承诺事项
2014/05/28 职场文书
国贸专业毕业求职信
2014/06/11 职场文书
学习走群众路线心得体会
2014/11/05 职场文书
《有余数的除法》教学反思
2016/02/22 职场文书