TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读文件逐行处理的示例代码分享
Dec 27 Python
下载给定网页上图片的方法
Feb 18 Python
python 容器总结整理
Apr 04 Python
Python基于回溯法子集树模板解决选排问题示例
Sep 07 Python
详解python string类型 bytes类型 bytearray类型
Dec 16 Python
django在接受post请求时显示403forbidden实例解析
Jan 25 Python
python利用socketserver实现并发套接字功能
Jan 26 Python
python写入数据到csv或xlsx文件的3种方法
Aug 23 Python
Python自动化完成tb喵币任务的操作方法
Oct 30 Python
解决Jupyter notebook中.py与.ipynb文件的import问题
Apr 21 Python
python能做哪方面的工作
Jun 15 Python
Python如何加载模型并查看网络
Jul 15 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
锁定年轻人的双倍活力 星巴克推出星倍醇即饮浓咖啡
2021/03/03 咖啡文化
非常好的php目录导航文件代码
2006/10/09 PHP
PHP 图像尺寸调整代码
2010/05/26 PHP
PHP eval函数使用介绍
2013/12/08 PHP
php多进程应用场景实例详解
2019/07/22 PHP
非常好的js代码
2006/06/27 Javascript
Javascript 个人笔记(没有整理,很乱)
2007/07/07 Javascript
JS实现图片翻书效果示例代码
2013/09/09 Javascript
JavaScript中for..in循环陷阱介绍
2013/11/12 Javascript
JS在可编辑的div中的光标位置插入内容的方法
2014/11/20 Javascript
使用pjax实现无刷新更改页面url
2015/02/05 Javascript
jQuery+css3实现Ajax点击后动态删除功能的方法
2015/08/10 Javascript
jQuery三级下拉列表导航菜单代码分享
2020/04/15 Javascript
JavaScript中使用数组方法汇总
2016/02/16 Javascript
KnockoutJS 3.X API 第四章之事件event绑定
2016/10/10 Javascript
Angularjs实现分页和分页算法的示例代码
2016/12/23 Javascript
Bootstrap表格制作代码
2017/03/17 Javascript
一篇看懂vuejs的状态管理神器 vuex状态管理模式
2017/04/20 Javascript
KOA+egg.js集成kafka消息队列的示例
2018/11/09 Javascript
js的继承方法小结(prototype、call、apply)(推荐)
2019/04/17 Javascript
ionic3双击返回退出应用的方法
2019/09/17 Javascript
[04:44]DOTA2 2017全国高校联赛视频回顾
2017/08/21 DOTA
python之django母板页面的使用
2018/07/03 Python
python实现括号匹配的思路详解
2018/08/23 Python
对python 匹配字符串开头和结尾的方法详解
2018/10/27 Python
一行Python代码过滤标点符号等特殊字符
2019/08/12 Python
使用Html5、CSS实现文字阴影效果
2018/01/17 HTML / CSS
什么是重载?CTS、CLS和CLR分别做何解释
2012/05/06 面试题
介绍一下RMI的基本概念
2016/12/17 面试题
统计员岗位职责
2013/11/14 职场文书
工厂门卫的岗位职责
2014/07/27 职场文书
119消防日活动总结
2014/08/29 职场文书
关于九一八事变的演讲稿2014
2014/09/17 职场文书
恰同学少年观后感
2015/06/08 职场文书
2016年学校爱国卫生月活动总结
2016/04/06 职场文书
详解GaussDB for MySQL性能优化
2021/05/18 MySQL