tensorflow实现KNN识别MNIST


Posted in Python onMarch 12, 2018

KNN算法算是最简单的机器学习算法之一了,这个算法最大的特点是没有训练过程,是一种懒惰学习,这种结构也可以在tensorflow实现。

KNN的最核心就是距离度量方式,官方例程给出的是L1范数的例子,我这里改成了L2范数,也就是我们常说的欧几里得距离度量,另外,虽然是叫KNN,意思是选取k个最接近的元素来投票产生分类,但是这里只是用了最近的那个数据的标签作为预测值了。

__author__ = 'freedom' 
import tensorflow as tf 
import numpy as np 
 
def loadMNIST(): 
 from tensorflow.examples.tutorials.mnist import input_data 
 mnist = input_data.read_data_sets('MNIST_data',one_hot=True) 
 return mnist 
def KNN(mnist): 
 train_x,train_y = mnist.train.next_batch(5000) 
 test_x,test_y = mnist.train.next_batch(200) 
 
 xtr = tf.placeholder(tf.float32,[None,784]) 
 xte = tf.placeholder(tf.float32,[784]) 
 distance = tf.sqrt(tf.reduce_sum(tf.pow(tf.add(xtr,tf.neg(xte)),2),reduction_indices=1)) 
 
 pred = tf.argmin(distance,0) 
 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 
 right = 0 
 for i in range(200): 
  ansIndex = sess.run(pred,{xtr:train_x,xte:test_x[i,:]}) 
  print 'prediction is ',np.argmax(train_y[ansIndex]) 
  print 'true value is ',np.argmax(test_y[i]) 
  if np.argmax(test_y[i]) == np.argmax(train_y[ansIndex]): 
   right += 1.0 
 accracy = right/200.0 
 print accracy 
 
if __name__ == "__main__": 
 mnist = loadMNIST() 
 KNN(mnist)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python在校内发人人网状态(人人网看状态)
Feb 19 Python
使用Python编写一个最基础的代码解释器的要点解析
Jul 12 Python
python用模块zlib压缩与解压字符串和文件的方法
Dec 16 Python
python微信公众号之关键词自动回复
Jun 15 Python
Flask框架通过Flask_login实现用户登录功能示例
Jul 17 Python
python截取两个单词之间的内容方法
Dec 25 Python
Python3中的f-Strings增强版字符串格式化方法
Mar 04 Python
详解tensorflow2.x版本无法调用gpu的一种解决方法
May 25 Python
python数据类型强制转换实例详解
Jun 22 Python
推荐值得学习的12款python-web开发框架
Aug 10 Python
关于python3.9安装wordcloud出错的问题及解决办法
Nov 02 Python
python异步的ASGI与Fast Api实现
Jul 16 Python
Python操作MySQL模拟银行转账
Mar 12 #Python
python3 图片referer防盗链的实现方法
Mar 12 #Python
tensorflow构建BP神经网络的方法
Mar 12 #Python
Python管理Windows服务小脚本
Mar 12 #Python
python实现教务管理系统
Mar 12 #Python
python编写弹球游戏的实现代码
Mar 12 #Python
python学生管理系统代码实现
Apr 05 #Python
You might like
Smarty中的注释和截断功能介绍
2015/04/09 PHP
php比较两个字符串长度的方法
2015/07/13 PHP
一个简单的js树形菜单
2011/12/09 Javascript
js 获取屏幕各种宽高的方法(浏览器兼容)
2013/05/15 Javascript
javascript ajax的5种状态介绍
2014/08/18 Javascript
js图片实时加载提供网页打开速度
2014/09/11 Javascript
js实现九宫格图片半透明渐显特效的方法
2015/02/16 Javascript
纯JavaScript代码实现移动设备绘图解锁
2015/10/16 Javascript
Angular学习笔记之angular的$filter服务浅析
2016/11/12 Javascript
在点击div中的p时,如何阻止事件冒泡
2017/02/07 Javascript
Vue插件写、用详解(附demo)
2017/03/20 Javascript
解决webpack打包速度慢的解决办法汇总
2017/07/06 Javascript
将 vue 生成的 js 上传到七牛的实例
2017/07/28 Javascript
Node.js  REPL (交互式解释器)实例详解
2017/08/06 Javascript
AngularJs 最新验证手机号码的实例,成功测试通过
2017/11/26 Javascript
Angular6 发送手机验证码按钮倒计时效果实现方法
2019/01/08 Javascript
JavaScript两种计时器的实例讲解
2019/01/31 Javascript
JS 实现发送短信验证码的“59秒后重新发送验证短信”功能
2019/08/23 Javascript
js判断非127开头的IP地址的实例代码
2020/01/05 Javascript
python实现发送邮件功能代码
2017/12/14 Python
python opencv旋转图像(保持图像不被裁减)
2018/07/26 Python
Python实现网站表单提交和模板
2019/01/15 Python
ZABBIX3.2使用python脚本实现监控报表的方法
2019/07/02 Python
Python正则表达式急速入门(小结)
2019/12/16 Python
pytorch之ImageFolder使用详解
2020/01/06 Python
如何在Django中使用聚合的实现示例
2020/03/23 Python
python 合并多个excel中同名的sheet
2021/01/22 Python
解决HTML5手机端页面缩放的问题
2017/10/27 HTML / CSS
保加利亚服装和鞋类购物网站:Bibloo.bg
2020/11/08 全球购物
Chinti & Parker官网:奢华羊绒女装和创新针织设计
2021/01/01 全球购物
青岛海底世界导游词
2015/02/11 职场文书
Vue实现导入Excel功能步骤详解
2021/07/03 Vue.js
python自动化测试之Selenium详解
2022/03/13 Python
Golang 1.18 多模块Multi-Module工作区模式的新特性
2022/04/11 Golang
详解SQL的窗口函数
2022/04/21 Oracle
IDEA 2022 Translation 未知错误 翻译文档失败
2022/04/24 Java/Android