kaggle+mnist实现手写字体识别


Posted in Python onJuly 26, 2018

现在的许多手写字体识别代码都是基于已有的mnist手写字体数据集进行的,而kaggle需要用到网站上给出的数据集并生成测试集的输出用于提交。这里选择keras搭建卷积网络进行识别,可以直接生成测试集的结果,最终结果识别率大概97%左右的样子。

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 6 19:07:10 2017

@author: Administrator
"""

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten 
from keras.layers import Convolution2D, MaxPooling2D 
from keras.utils import np_utils
import os
import pandas as pd
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras import backend as K
import tensorflow as tf

# 全局变量 
batch_size = 100 
nb_classes = 10 
epochs = 20
# input image dimensions 
img_rows, img_cols = 28, 28 
# number of convolutional filters to use 
nb_filters = 32 
# size of pooling area for max pooling 
pool_size = (2, 2) 
# convolution kernel size 
kernel_size = (3, 3) 

inputfile='F:/data/kaggle/mnist/train.csv'
inputfile2= 'F:/data/kaggle/mnist/test.csv'
outputfile= 'F:/data/kaggle/mnist/test_label.csv'


pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
train= pd.read_csv(os.path.basename(inputfile)) #从训练数据文件读取数据
os.chdir(pwd)

pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
test= pd.read_csv(os.path.basename(inputfile2)) #从测试数据文件读取数据
os.chdir(pwd)

x_train=train.iloc[:,1:785] #得到特征数据
y_train=train['label']
y_train = np_utils.to_categorical(y_train, 10)

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #导入数据
x_test=mnist.test.images
y_test=mnist.test.labels
# 根据不同的backend定下不同的格式 
if K.image_dim_ordering() == 'th': 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 
 x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 
 input_shape = (1, img_rows, img_cols) 
 test = test.reshape(test.shape[0], 1, img_rows, img_cols) 
else: 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
 X_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
 test = test.reshape(test.shape[0], img_rows, img_cols, 1) 
 input_shape = (img_rows, img_cols, 1) 

x_train = x_train.astype('float32') 
x_test = X_test.astype('float32') 
test = test.astype('float32') 
x_train /= 255 
X_test /= 255
test/=255 
print('X_train shape:', x_train.shape) 
print(x_train.shape[0], 'train samples') 
print(x_test.shape[0], 'test samples') 
print(test.shape[0], 'testOuput samples') 

model=Sequential()#model initial
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]), 
      padding='same', 
      input_shape=input_shape)) # 卷积层1 
model.add(Activation('relu')) #激活层 
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]))) #卷积层2 
model.add(Activation('relu')) #激活层 
model.add(MaxPooling2D(pool_size=pool_size)) #池化层 
model.add(Dropout(0.25)) #神经元随机失活 
model.add(Flatten()) #拉成一维数据 
model.add(Dense(128)) #全连接层1 
model.add(Activation('relu')) #激活层 
model.add(Dropout(0.5)) #随机失活 
model.add(Dense(nb_classes)) #全连接层2 
model.add(Activation('softmax')) #Softmax评分 

#编译模型 
model.compile(loss='categorical_crossentropy', 
    optimizer='adadelta', 
    metrics=['accuracy']) 
#训练模型 

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,verbose=1) 
model.predict(x_test)
#评估模型 
score = model.evaluate(x_test, y_test, verbose=0) 
print('Test score:', score[0]) 
print('Test accuracy:', score[1]) 

y_test=model.predict(test)

sess=tf.InteractiveSession()
y_test=sess.run(tf.arg_max(y_test,1))
y_test=pd.DataFrame(y_test)
y_test.to_csv(outputfile)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(一)
Jun 09 Python
python中for语句简单遍历数据的方法
May 07 Python
python分割列表(list)的方法示例
May 07 Python
Python常用时间操作总结【取得当前时间、时间函数、应用等】
May 11 Python
Python RabbitMQ消息队列实现rpc
May 30 Python
带你认识Django
Jan 15 Python
PyCharm2018 安装及破解方法实现步骤
Sep 09 Python
python入门之基础语法学习笔记
Feb 08 Python
tensorflow实现残差网络方式(mnist数据集)
May 26 Python
python 简单的调用有道翻译
Nov 25 Python
python中requests库+xpath+lxml简单使用
Apr 29 Python
基于Python实现nc批量转tif格式
Aug 14 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 #Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
You might like
php 显示指定路径下的图片
2009/10/29 PHP
Yii2创建表单(ActiveForm)方法详解
2016/07/23 PHP
关于PHP通用返回值设置方法
2017/03/31 PHP
基于jquery的跨域调用文件
2010/11/19 Javascript
让图片旋转任意角度及JQuery插件使用介绍
2013/03/20 Javascript
js 判断文件类型并控制表单提交示例代码
2013/11/14 Javascript
浅析JavaScript中的delete运算符
2013/11/30 Javascript
JavaScript搜索字符串并将搜索结果返回到字符串的方法
2015/04/06 Javascript
JavaScript创建闭包的两种方式的优劣与区别分析
2015/06/22 Javascript
jquery实现的Accordion折叠面板效果代码
2015/09/02 Javascript
字符串反转_JavaScript
2016/04/28 Javascript
如何利用JSHint减少JavaScript的错误
2016/08/23 Javascript
浅析webpack 如何优雅的使用tree-shaking(摇树优化)
2017/08/16 Javascript
Angular移动端页面input无法输入的解决方法
2017/11/14 Javascript
解决JSON.stringify()自动将中文转译成unicode的问题
2018/01/05 Javascript
React Native悬浮按钮组件的示例代码
2018/04/05 Javascript
Node.js中package.json中库的版本号(~和^)
2019/04/02 Javascript
Node.js API详解之 V8模块用法实例分析
2020/06/05 Javascript
vue实现点击出现操作弹出框的示例
2020/11/05 Javascript
[04:31]2016国际邀请赛中国区预选赛妖精采访
2016/06/27 DOTA
python实现随机漫步方法和原理
2019/06/10 Python
Python中的几种矩阵乘法(小结)
2019/07/10 Python
numpy的Fancy Indexing和array比较详解
2020/06/11 Python
Python实时监控网站浏览记录实现过程详解
2020/07/14 Python
使用Python快速打开一个百万行级别的超大Excel文件的方法
2021/03/02 Python
html5利用canvas绘画二级树形结构图的示例
2017/09/27 HTML / CSS
教师党性分析材料
2014/02/04 职场文书
制作部班长职位说明书
2014/02/26 职场文书
我的长生果教学反思
2014/04/28 职场文书
班主任开场白
2015/06/01 职场文书
2015年信息化建设工作总结
2015/07/23 职场文书
学会感恩主题班会
2015/08/12 职场文书
2019年销售人员的职业生涯规划书
2019/03/25 职场文书
门面租赁合同范文
2019/08/06 职场文书
SpringBoot SpringEL表达式的使用
2021/07/25 Java/Android
golang操作rocketmq的示例代码
2022/04/06 Golang