使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python类的专用方法实例分析
Jan 09 Python
Python迭代器和生成器介绍
Mar 06 Python
Python的Tornado框架实现图片上传及图片大小修改功能
Jun 30 Python
OpenCV-Python实现轮廓检测实例分析
Jan 05 Python
python实现自动发送邮件
Jun 20 Python
Python动态生成多维数组的方法示例
Aug 09 Python
python实现动态创建类的方法分析
Jun 25 Python
利用python numpy+matplotlib绘制股票k线图的方法
Jun 26 Python
Python Django 简单分页的实现代码解析
Aug 21 Python
Python run()函数和start()函数的比较和差别介绍
May 03 Python
tensorflow图像裁剪进行数据增强操作
Jun 30 Python
python Matplotlib数据可视化(2):详解三大容器对象与常用设置
Sep 30 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
杏林同学录(六)
2006/10/09 PHP
从零开始学YII2框架(四)扩展插件yii2-kartikgii
2014/08/20 PHP
封装ThinkPHP的一个文件上传方法实例
2014/10/31 PHP
php常用hash加密函数
2014/11/22 PHP
PHP CodeIgniter分页实例及多条件查询解决方案(推荐)
2017/05/20 PHP
ExtJS 工具栏 分页事件参数
2010/03/05 Javascript
jquery之empty()与remove()区别说明
2010/09/10 Javascript
JavaScript实现的一个计算数字步数的算法分享
2014/12/06 Javascript
JS ES6中setTimeout函数的执行上下文示例
2017/04/27 Javascript
vue2.0 父组件给子组件传递数据的方法
2018/01/15 Javascript
Vue源码解读之Component组件注册的实现
2018/08/24 Javascript
three.js实现炫酷的全景3D重力感应
2018/12/30 Javascript
jQuery选择器之基本选择器用法实例分析
2019/02/19 jQuery
详解滑动穿透(锁body)终极探索
2019/04/16 Javascript
jquery检测上传文件大小示例
2020/04/26 jQuery
elementUI同一页面展示多个Dialog的实现
2020/11/19 Javascript
[02:38]DOTA2亚洲邀请赛小组赛精彩集锦:Wings完美团击溃对手
2017/03/29 DOTA
Python函数参数类型*、**的区别
2015/04/11 Python
Python 使用requests模块发送GET和POST请求的实现代码
2016/09/21 Python
新手如何快速入门Python(菜鸟必看篇)
2017/06/10 Python
python用BeautifulSoup库简单爬虫实例分析
2018/07/30 Python
详解利用django中间件django.middleware.csrf.CsrfViewMiddleware防止csrf攻击
2018/10/09 Python
PyQt QCombobox设置行高的方法
2019/06/20 Python
在TensorFlow中实现矩阵维度扩展
2020/05/22 Python
解决python图像处理图像赋值后变为白色的问题
2020/06/04 Python
Pytest如何使用skip跳过执行测试
2020/08/13 Python
python 使用三引号时容易犯的小错误
2020/10/21 Python
Python 微信公众号文章爬取的示例代码
2020/11/30 Python
canvas生成带二维码海报的踩坑记录
2019/09/11 HTML / CSS
Canvas波浪花环的示例代码
2020/08/21 HTML / CSS
人力资源主管的岗位职责
2014/03/15 职场文书
毕业横幅标语
2014/10/08 职场文书
社区四风存在问题及整改措施
2014/10/26 职场文书
2016年植树节红领巾广播稿
2015/12/17 职场文书
小程序实现文字循环滚动动画
2021/06/14 Javascript
小程序实现侧滑删除功能
2022/06/25 Javascript