TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单FTP上传下载文件实例
Jun 30 Python
详解python实现线程安全的单例模式
Mar 05 Python
Python面向对象之类的内置attr属性示例
Dec 14 Python
如何使用python进行pdf文件分割
Nov 11 Python
python实现高斯投影正反算方式
Jan 17 Python
Python3 元组tuple入门基础
Feb 09 Python
django-crontab实现服务端的定时任务的示例代码
Feb 17 Python
python实现udp聊天窗口
Mar 31 Python
python except异常处理之后不退出,解决异常继续执行的实现
Apr 25 Python
tensorflow2.0的函数签名与图结构(推荐)
Apr 28 Python
python转化excel数字日期为标准日期操作
Jul 14 Python
python RSA加密的示例
Dec 09 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
DC《小丑》11项提名领跑奥斯卡 Netflix成第92届奥斯卡提名最大赢家
2020/04/09 欧美动漫
深入PHP curl参数的详解
2013/06/17 PHP
WampServer下安装多个版本的PHP、mysql、apache图文教程
2015/01/07 PHP
新浪SAE搭建PHP项目教程
2015/01/28 PHP
php去除头尾空格的2种方法
2015/03/16 PHP
php微信公众号开发之答题连闯三关
2018/10/20 PHP
通过js脚本复制网页上的一个表格的不错实现方法
2006/12/29 Javascript
JavaScript 基础知识 被自己遗忘的
2009/10/15 Javascript
关于javascript event flow 的一个bug详解
2013/09/17 Javascript
基于JavaScript实现仿京东图片轮播效果
2015/11/06 Javascript
jqGrid 学习笔记整理——进阶篇(一 )
2016/04/17 Javascript
javascript函数中的3个高级技巧
2016/09/22 Javascript
微信小程序 欢迎界面开发的实例详解
2016/11/30 Javascript
React-Native做一个文本输入框组件的实现代码
2017/08/10 Javascript
轻松玩转BootstrapTable(后端使用SpringMVC+Hibernate)
2017/09/06 Javascript
JavaScript中EventLoop介绍
2018/01/22 Javascript
微信小程序slider组件使用详解
2018/01/31 Javascript
利用nodeJs anywhere搭建本地服务器环境的方法
2018/05/12 NodeJs
Vue-cli配置打包文件本地使用的教程图解
2018/08/02 Javascript
javascript读取本地文件和目录方法详解
2020/08/06 Javascript
python使用pil生成图片验证码的方法
2015/05/08 Python
Python绘制的二项分布概率图示例
2018/08/22 Python
opencv-python 提取sift特征并匹配的实例
2019/12/09 Python
python关于调用函数外的变量实例
2019/12/26 Python
python 将dicom图片转换成jpg图片的实例
2020/01/13 Python
Django之choices选项和富文本编辑器的使用详解
2020/04/01 Python
基于Keras的格式化输出Loss实现方式
2020/06/17 Python
html5时钟实现代码
2010/10/22 HTML / CSS
HTML5中判断用户是否正在浏览页面的方法
2014/05/03 HTML / CSS
贪睡宠物用品:Snoozer Pet Products
2020/02/04 全球购物
日本最大的彩色隐形眼镜销售网站:CharmColor
2020/09/09 全球购物
网络工程师的自我评价
2013/10/02 职场文书
2014两会学习心得:榜样精神伴我行
2014/03/17 职场文书
课外访万家心得体会
2014/09/03 职场文书
搬迁通知
2015/04/20 职场文书
大学生各类奖学金申请书
2019/06/24 职场文书