kaggle+mnist实现手写字体识别


Posted in Python onJuly 26, 2018

现在的许多手写字体识别代码都是基于已有的mnist手写字体数据集进行的,而kaggle需要用到网站上给出的数据集并生成测试集的输出用于提交。这里选择keras搭建卷积网络进行识别,可以直接生成测试集的结果,最终结果识别率大概97%左右的样子。

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 6 19:07:10 2017

@author: Administrator
"""

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten 
from keras.layers import Convolution2D, MaxPooling2D 
from keras.utils import np_utils
import os
import pandas as pd
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras import backend as K
import tensorflow as tf

# 全局变量 
batch_size = 100 
nb_classes = 10 
epochs = 20
# input image dimensions 
img_rows, img_cols = 28, 28 
# number of convolutional filters to use 
nb_filters = 32 
# size of pooling area for max pooling 
pool_size = (2, 2) 
# convolution kernel size 
kernel_size = (3, 3) 

inputfile='F:/data/kaggle/mnist/train.csv'
inputfile2= 'F:/data/kaggle/mnist/test.csv'
outputfile= 'F:/data/kaggle/mnist/test_label.csv'


pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
train= pd.read_csv(os.path.basename(inputfile)) #从训练数据文件读取数据
os.chdir(pwd)

pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
test= pd.read_csv(os.path.basename(inputfile2)) #从测试数据文件读取数据
os.chdir(pwd)

x_train=train.iloc[:,1:785] #得到特征数据
y_train=train['label']
y_train = np_utils.to_categorical(y_train, 10)

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #导入数据
x_test=mnist.test.images
y_test=mnist.test.labels
# 根据不同的backend定下不同的格式 
if K.image_dim_ordering() == 'th': 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 
 x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 
 input_shape = (1, img_rows, img_cols) 
 test = test.reshape(test.shape[0], 1, img_rows, img_cols) 
else: 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
 X_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
 test = test.reshape(test.shape[0], img_rows, img_cols, 1) 
 input_shape = (img_rows, img_cols, 1) 

x_train = x_train.astype('float32') 
x_test = X_test.astype('float32') 
test = test.astype('float32') 
x_train /= 255 
X_test /= 255
test/=255 
print('X_train shape:', x_train.shape) 
print(x_train.shape[0], 'train samples') 
print(x_test.shape[0], 'test samples') 
print(test.shape[0], 'testOuput samples') 

model=Sequential()#model initial
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]), 
      padding='same', 
      input_shape=input_shape)) # 卷积层1 
model.add(Activation('relu')) #激活层 
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]))) #卷积层2 
model.add(Activation('relu')) #激活层 
model.add(MaxPooling2D(pool_size=pool_size)) #池化层 
model.add(Dropout(0.25)) #神经元随机失活 
model.add(Flatten()) #拉成一维数据 
model.add(Dense(128)) #全连接层1 
model.add(Activation('relu')) #激活层 
model.add(Dropout(0.5)) #随机失活 
model.add(Dense(nb_classes)) #全连接层2 
model.add(Activation('softmax')) #Softmax评分 

#编译模型 
model.compile(loss='categorical_crossentropy', 
    optimizer='adadelta', 
    metrics=['accuracy']) 
#训练模型 

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,verbose=1) 
model.predict(x_test)
#评估模型 
score = model.evaluate(x_test, y_test, verbose=0) 
print('Test score:', score[0]) 
print('Test accuracy:', score[1]) 

y_test=model.predict(test)

sess=tf.InteractiveSession()
y_test=sess.run(tf.arg_max(y_test,1))
y_test=pd.DataFrame(y_test)
y_test.to_csv(outputfile)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python异步任务队列示例
Apr 01 Python
Python抓取京东图书评论数据
Aug 31 Python
python实现udp数据报传输的方法
Sep 26 Python
python实现矩阵乘法的方法
Jun 28 Python
Python编程实现两个文件夹里文件的对比功能示例【包含内容的对比】
Jun 20 Python
pycharm+django创建一个搜索网页实例代码
Jan 24 Python
python修改list中所有元素类型的三种方法
Apr 09 Python
TensorFlow Session使用的两种方法小结
Jul 30 Python
一文秒懂python读写csv xml json文件各种骚操作
Jul 04 Python
python自动点赞功能的实现思路
Feb 26 Python
Python基于字典实现switch case函数调用
Jul 22 Python
详解利用python识别图片中的条码(pyzbar)及条码图片矫正和增强
Nov 17 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 #Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
You might like
重新封装zend_soap实现http连接安全认证的php代码
2011/01/12 PHP
PHP SPL使用方法和他的威力
2013/11/12 PHP
php实现背景图上添加圆形logo图标的方法
2016/11/17 PHP
php 运算符与表达式详细介绍
2016/11/30 PHP
使用Apache的rewrite
2021/03/09 Servers
JS通过相同的name进行表格求和代码
2013/08/18 Javascript
文本框回车提交与禁止提交示例
2013/09/27 Javascript
用javascript删除当前行,添加行(示例代码)
2013/11/25 Javascript
js中opener与parent的区别详细解析
2014/01/14 Javascript
FF(火狐)浏览器无法执行window.close()解决方案
2014/11/13 Javascript
Bootstrap源码解读导航(6)
2016/12/23 Javascript
Angularjs分页查询的实现
2017/02/24 Javascript
ES6中的箭头函数实例详解
2017/04/06 Javascript
JS实现仿UC浏览器前进后退效果的实例代码
2017/07/17 Javascript
利用Angular7开发一个Radio组件的全过程
2019/07/11 Javascript
JS随机密码生成算法
2019/09/23 Javascript
vue中监听路由参数的变化及方法
2019/12/06 Javascript
Vue v-bind动态绑定class实例方法
2020/01/15 Javascript
[01:00:06]加油DOTA_EP01_网络版
2014/08/09 DOTA
定制FileField中的上传文件名称实例
2017/08/23 Python
python通过百度地图API获取某地址的经纬度详解
2018/01/28 Python
详解python中的数据类型和控制流
2019/08/08 Python
pytorch实现用Resnet提取特征并保存为txt文件的方法
2019/08/20 Python
如何基于Python爬取隐秘的角落评论
2020/07/02 Python
python爬虫实现爬取同一个网站的多页数据的实例讲解
2021/01/18 Python
h5使用canvas画布实现手势解锁
2019/01/04 HTML / CSS
英国翻新电子产品购物网站:Tech Trade
2017/12/25 全球购物
椰子猫砂:CatSpot
2018/08/27 全球购物
制药工程专业应届生求职信
2013/09/24 职场文书
简历里的自我评价范文
2014/02/24 职场文书
丧事主持词大全
2014/04/02 职场文书
市场营销工作计划书
2014/09/15 职场文书
2015社区个人工作总结范文
2015/05/13 职场文书
2015年科普工作总结
2015/07/23 职场文书
写一个Python脚本自动爬取Bilibili小视频
2021/04/24 Python
浅谈Python数学建模之整数规划
2021/06/23 Python