kaggle+mnist实现手写字体识别


Posted in Python onJuly 26, 2018

现在的许多手写字体识别代码都是基于已有的mnist手写字体数据集进行的,而kaggle需要用到网站上给出的数据集并生成测试集的输出用于提交。这里选择keras搭建卷积网络进行识别,可以直接生成测试集的结果,最终结果识别率大概97%左右的样子。

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 6 19:07:10 2017

@author: Administrator
"""

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten 
from keras.layers import Convolution2D, MaxPooling2D 
from keras.utils import np_utils
import os
import pandas as pd
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras import backend as K
import tensorflow as tf

# 全局变量 
batch_size = 100 
nb_classes = 10 
epochs = 20
# input image dimensions 
img_rows, img_cols = 28, 28 
# number of convolutional filters to use 
nb_filters = 32 
# size of pooling area for max pooling 
pool_size = (2, 2) 
# convolution kernel size 
kernel_size = (3, 3) 

inputfile='F:/data/kaggle/mnist/train.csv'
inputfile2= 'F:/data/kaggle/mnist/test.csv'
outputfile= 'F:/data/kaggle/mnist/test_label.csv'


pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
train= pd.read_csv(os.path.basename(inputfile)) #从训练数据文件读取数据
os.chdir(pwd)

pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
test= pd.read_csv(os.path.basename(inputfile2)) #从测试数据文件读取数据
os.chdir(pwd)

x_train=train.iloc[:,1:785] #得到特征数据
y_train=train['label']
y_train = np_utils.to_categorical(y_train, 10)

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #导入数据
x_test=mnist.test.images
y_test=mnist.test.labels
# 根据不同的backend定下不同的格式 
if K.image_dim_ordering() == 'th': 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 
 x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 
 input_shape = (1, img_rows, img_cols) 
 test = test.reshape(test.shape[0], 1, img_rows, img_cols) 
else: 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
 X_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
 test = test.reshape(test.shape[0], img_rows, img_cols, 1) 
 input_shape = (img_rows, img_cols, 1) 

x_train = x_train.astype('float32') 
x_test = X_test.astype('float32') 
test = test.astype('float32') 
x_train /= 255 
X_test /= 255
test/=255 
print('X_train shape:', x_train.shape) 
print(x_train.shape[0], 'train samples') 
print(x_test.shape[0], 'test samples') 
print(test.shape[0], 'testOuput samples') 

model=Sequential()#model initial
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]), 
      padding='same', 
      input_shape=input_shape)) # 卷积层1 
model.add(Activation('relu')) #激活层 
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]))) #卷积层2 
model.add(Activation('relu')) #激活层 
model.add(MaxPooling2D(pool_size=pool_size)) #池化层 
model.add(Dropout(0.25)) #神经元随机失活 
model.add(Flatten()) #拉成一维数据 
model.add(Dense(128)) #全连接层1 
model.add(Activation('relu')) #激活层 
model.add(Dropout(0.5)) #随机失活 
model.add(Dense(nb_classes)) #全连接层2 
model.add(Activation('softmax')) #Softmax评分 

#编译模型 
model.compile(loss='categorical_crossentropy', 
    optimizer='adadelta', 
    metrics=['accuracy']) 
#训练模型 

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,verbose=1) 
model.predict(x_test)
#评估模型 
score = model.evaluate(x_test, y_test, verbose=0) 
print('Test score:', score[0]) 
print('Test accuracy:', score[1]) 

y_test=model.predict(test)

sess=tf.InteractiveSession()
y_test=sess.run(tf.arg_max(y_test,1))
y_test=pd.DataFrame(y_test)
y_test.to_csv(outputfile)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作MySQL数据库的方法分享
May 29 Python
python 字符串格式化代码
Mar 17 Python
Python2.6版本中实现字典推导 PEP 274(Dict Comprehensions)
Apr 28 Python
简单介绍Python中的readline()方法的使用
May 24 Python
python中zip和unzip数据的方法
May 27 Python
在Django的视图中使用form对象的方法
Jul 18 Python
Python 调用Java实例详解
Jun 02 Python
基于DATAFRAME中元素的读取与修改方法
Jun 08 Python
解决pycharm不能自动补全第三方库的函数和属性问题
Mar 12 Python
python3中布局背景颜色代码分析
Dec 01 Python
python 将Excel转Word的示例
Mar 02 Python
Python办公自动化解决world文件批量转换
Sep 15 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 #Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
You might like
ThinkPHP自动完成中使用函数与回调方法实例
2014/11/29 PHP
php生成与读取excel文件
2016/10/14 PHP
js继承 Base类的源码解析
2008/12/30 Javascript
jQuery 图像裁剪插件Jcrop的简单使用
2009/05/22 Javascript
extjs grid设置某列背景颜色和字体颜色的方法
2010/09/03 Javascript
Dom操作之兼容技巧分享
2011/09/20 Javascript
jQuery实现动态表单验证时文本框抖动效果完整实例
2015/08/21 Javascript
JavaScript中的this到底是什么(一)
2015/12/09 Javascript
PhotoSwipe异步动态加载图片方法
2016/08/25 Javascript
vue.js实现备忘录功能的方法
2017/07/10 Javascript
React-Native 组件之 Modal的使用详解
2017/08/08 Javascript
Angular Material Icon使用详解
2018/11/07 Javascript
用Electron写个带界面的nodejs爬虫的实现方法
2019/01/29 NodeJs
js核心基础之闭包的应用实例分析
2019/05/11 Javascript
JS数组扁平化、去重、排序操作实例详解
2020/02/24 Javascript
详解JavaScript匿名函数和闭包
2020/07/10 Javascript
vue 解决mintui弹窗弹起来,底部页面滚动bug问题
2020/11/12 Javascript
[01:05:30]VP vs TNC 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
Python提取Linux内核源代码的目录结构实现方法
2016/06/24 Python
Python面向对象编程基础解析(二)
2017/10/26 Python
Python3.6简单反射操作示例
2018/06/14 Python
python实现静态服务器
2019/09/05 Python
numpy.array 操作使用简单总结
2019/11/08 Python
Python多进程编程常用方法解析
2020/03/26 Python
使用BeautifulSoup4解析XML的方法小结
2020/12/07 Python
python 日志模块logging的使用场景及示例
2021/01/04 Python
Luxplus瑞典:香水和美容护理折扣
2018/01/28 全球购物
销售行政专员职责
2014/01/03 职场文书
特色冷饮店创业计划书
2014/01/28 职场文书
售后客服工作职责
2014/06/16 职场文书
我的梦想演讲稿500字
2014/08/21 职场文书
党员学习群众路线心得体会
2014/11/04 职场文书
银行竞聘报告范文
2014/11/06 职场文书
JavaScript 数组去重详解
2021/09/15 Javascript
分析MySQL优化 index merge 后引起的死锁
2022/04/19 MySQL
JavaScript设计模式之原型模式详情
2022/06/21 Javascript