TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中内置数据类型list,tuple,dict,set的区别和用法
Dec 14 Python
python3.5使用tkinter制作记事本
Jun 20 Python
python使用正则表达式替换匹配成功的组
Nov 17 Python
Python 按字典dict的键排序,并取出相应的键值放于list中的实例
Feb 12 Python
对python 调用类属性的方法详解
Jul 02 Python
Python 转换RGB颜色值的示例代码
Oct 13 Python
python-sys.stdout作为默认函数参数的实现
Feb 21 Python
python闭包、深浅拷贝、垃圾回收、with语句知识点汇总
Mar 11 Python
Python如何安装第三方模块
May 28 Python
解决python 虚拟环境删除包无法加载的问题
Jul 13 Python
Python操作MySQL数据库的示例代码
Jul 13 Python
scrapy-splash简单使用详解
Feb 21 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
虫族 Zerg 魔法科技
2020/03/14 星际争霸
在任意字符集下正常显示网页的方法二(续)
2007/04/01 PHP
php图片验证码代码
2008/03/27 PHP
PHP 数字左侧自动补0
2008/03/31 PHP
PHP 动态随机生成验证码类代码
2010/04/09 PHP
UCenter中的一个可逆加密函数authcode函数代码
2010/07/20 PHP
PHP获取MSN好友列表类的实现代码
2013/06/23 PHP
PHP中把错误日志保存在系统日志中(Windows系统)
2015/06/23 PHP
thinkphp5使用无限极分类
2019/02/18 PHP
PHP中16个高危函数整理
2019/09/19 PHP
javascript offsetX与layerX区别
2010/03/12 Javascript
javascript 用原型继承来实现对象系统
2010/03/22 Javascript
JQuery实现简单时尚快捷的气泡提示插件
2012/12/20 Javascript
js复制网页内容并兼容各主流浏览器的代码
2013/12/17 Javascript
js 异步操作回调函数如何控制执行顺序
2013/12/24 Javascript
让html页面不缓存js的实现方法
2014/10/31 Javascript
关于Javascript加载执行优化的研究报告
2014/12/16 Javascript
jQuery 中的 DOM 操作
2016/04/26 Javascript
小程序scroll-view组件实现滚动的示例代码
2018/09/20 Javascript
[14:21]VICI vs EG (BO3)
2018/06/07 DOTA
详解Python对JSON中的特殊类型进行Encoder
2019/07/15 Python
python使用 __init__初始化操作简单示例
2019/09/26 Python
CSS伪类与CSS伪元素的区别及由来具体说明
2012/12/07 HTML / CSS
CSS3新增布局之: flex详解
2020/06/18 HTML / CSS
日本PLST在线商店:日本时尚杂志刊载的人气服装
2016/12/10 全球购物
希特勒经典演讲稿
2014/05/19 职场文书
护士长2014年终工作总结
2014/11/11 职场文书
怎样写离婚协议书
2015/01/26 职场文书
忠诚与背叛观后感
2015/06/04 职场文书
结婚仪式主持词
2015/06/29 职场文书
2015年小学重阳节活动总结
2015/07/29 职场文书
有关花店创业的计划书模板
2019/08/27 职场文书
MySQL中出现乱码问题的终极解决宝典
2021/05/26 MySQL
Nginx+Tomcat负载均衡集群的实现示例
2021/10/24 Servers
Python Matplotlib绘制等高线图与渐变色扇形图
2022/04/14 Python
清空 Oracle 安装记录并重新安装
2022/04/26 Oracle