kaggle+mnist实现手写字体识别


Posted in Python onJuly 26, 2018

现在的许多手写字体识别代码都是基于已有的mnist手写字体数据集进行的,而kaggle需要用到网站上给出的数据集并生成测试集的输出用于提交。这里选择keras搭建卷积网络进行识别,可以直接生成测试集的结果,最终结果识别率大概97%左右的样子。

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 6 19:07:10 2017

@author: Administrator
"""

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten 
from keras.layers import Convolution2D, MaxPooling2D 
from keras.utils import np_utils
import os
import pandas as pd
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras import backend as K
import tensorflow as tf

# 全局变量 
batch_size = 100 
nb_classes = 10 
epochs = 20
# input image dimensions 
img_rows, img_cols = 28, 28 
# number of convolutional filters to use 
nb_filters = 32 
# size of pooling area for max pooling 
pool_size = (2, 2) 
# convolution kernel size 
kernel_size = (3, 3) 

inputfile='F:/data/kaggle/mnist/train.csv'
inputfile2= 'F:/data/kaggle/mnist/test.csv'
outputfile= 'F:/data/kaggle/mnist/test_label.csv'


pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
train= pd.read_csv(os.path.basename(inputfile)) #从训练数据文件读取数据
os.chdir(pwd)

pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
test= pd.read_csv(os.path.basename(inputfile2)) #从测试数据文件读取数据
os.chdir(pwd)

x_train=train.iloc[:,1:785] #得到特征数据
y_train=train['label']
y_train = np_utils.to_categorical(y_train, 10)

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #导入数据
x_test=mnist.test.images
y_test=mnist.test.labels
# 根据不同的backend定下不同的格式 
if K.image_dim_ordering() == 'th': 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 
 x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 
 input_shape = (1, img_rows, img_cols) 
 test = test.reshape(test.shape[0], 1, img_rows, img_cols) 
else: 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
 X_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
 test = test.reshape(test.shape[0], img_rows, img_cols, 1) 
 input_shape = (img_rows, img_cols, 1) 

x_train = x_train.astype('float32') 
x_test = X_test.astype('float32') 
test = test.astype('float32') 
x_train /= 255 
X_test /= 255
test/=255 
print('X_train shape:', x_train.shape) 
print(x_train.shape[0], 'train samples') 
print(x_test.shape[0], 'test samples') 
print(test.shape[0], 'testOuput samples') 

model=Sequential()#model initial
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]), 
      padding='same', 
      input_shape=input_shape)) # 卷积层1 
model.add(Activation('relu')) #激活层 
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]))) #卷积层2 
model.add(Activation('relu')) #激活层 
model.add(MaxPooling2D(pool_size=pool_size)) #池化层 
model.add(Dropout(0.25)) #神经元随机失活 
model.add(Flatten()) #拉成一维数据 
model.add(Dense(128)) #全连接层1 
model.add(Activation('relu')) #激活层 
model.add(Dropout(0.5)) #随机失活 
model.add(Dense(nb_classes)) #全连接层2 
model.add(Activation('softmax')) #Softmax评分 

#编译模型 
model.compile(loss='categorical_crossentropy', 
    optimizer='adadelta', 
    metrics=['accuracy']) 
#训练模型 

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,verbose=1) 
model.predict(x_test)
#评估模型 
score = model.evaluate(x_test, y_test, verbose=0) 
print('Test score:', score[0]) 
print('Test accuracy:', score[1]) 

y_test=model.predict(test)

sess=tf.InteractiveSession()
y_test=sess.run(tf.arg_max(y_test,1))
y_test=pd.DataFrame(y_test)
y_test.to_csv(outputfile)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读写ini文件示例(python读写文件)
Mar 25 Python
python处理二进制数据的方法
Jun 03 Python
Python字符串格式化
Jun 15 Python
浅谈pandas中shift和diff函数关系
Apr 08 Python
深入浅析Python中list的复制及深拷贝与浅拷贝
Sep 03 Python
Python简单过滤字母和数字的方法小结
Jan 09 Python
详解numpy矩阵的创建与数据类型
Oct 18 Python
Python利用PyPDF2库获取PDF文件总页码实例
Apr 03 Python
Spark处理数据排序问题如何避免OOM
May 21 Python
tensorflow 大于某个值为1,小于为0的实例
Jun 30 Python
Python urllib request模块发送请求实现过程解析
Dec 10 Python
python实现简单的聊天小程序
Jul 07 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 #Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
You might like
PHP syntax error, unexpected $end 错误的一种原因及解决
2008/10/25 PHP
PHP中改变图片的尺寸大小的代码
2011/07/17 PHP
编写php应用程序实现摘要式身份验证的方法详解
2013/06/08 PHP
PHP多个文件上传到服务器实例
2014/10/29 PHP
PHP中set_include_path()函数相关用法分析
2016/07/18 PHP
PHP实践教程之过滤、验证、转义与密码详解
2017/07/24 PHP
Laravel框架使用Seeder实现自动填充数据功能
2018/06/13 PHP
jQuery结合Json提交数据到Webservice,并接收从Webservice返回的Json数据
2011/02/18 Javascript
火狐textarea输入法的bug的触发及解决
2013/07/24 Javascript
可自定义速度的js图片无缝滚动示例分享
2014/01/20 Javascript
使用jquery选择器如何获取父级元素、同级元素、子元素
2014/05/14 Javascript
Jquery中find与each方法用法实例
2015/02/04 Javascript
关于JS中的apply,call,bind的深入解析
2016/04/05 Javascript
js判断主流浏览器类型和版本号的简单实现代码
2016/05/26 Javascript
jquery实现提示语淡入效果
2017/05/05 jQuery
Vue.js弹出模态框组件开发的示例代码
2017/07/26 Javascript
纯JavaScript实现实时反馈系统时间
2017/10/26 Javascript
vue绑定事件后获取绑定事件中的this方法
2018/09/15 Javascript
three.js搭建室内场景教程
2018/12/30 Javascript
使用Vue开发自己的Chrome扩展程序过程详解
2019/06/21 Javascript
[01:06:42]VP vs NewBee Supermajor 胜者组 BO3 第二场 6.5
2018/06/06 DOTA
在Lighttpd服务器中运行Django应用的方法
2015/07/22 Python
python SVM 线性分类模型的实现
2019/07/19 Python
Python Django 实现简单注册功能过程详解
2019/07/29 Python
python使用 cx_Oracle 模块进行查询操作示例
2019/11/28 Python
python爬虫学习笔记之pyquery模块基本用法详解
2020/04/09 Python
匡威西班牙官网:Converse西班牙
2019/10/01 全球购物
Windows和Linux动态库应用异同
2016/07/28 面试题
旅游与酒店管理的自我评价分享
2013/11/03 职场文书
办公室文书岗位职责
2013/12/16 职场文书
大学生职业生涯规划方案
2014/01/03 职场文书
婚礼证婚人证婚词
2014/01/13 职场文书
大学社团计划书
2014/05/01 职场文书
碧霞祠导游词
2015/02/09 职场文书
小学校长个人总结
2015/03/03 职场文书
革命电影观后感
2015/06/18 职场文书