kaggle+mnist实现手写字体识别


Posted in Python onJuly 26, 2018

现在的许多手写字体识别代码都是基于已有的mnist手写字体数据集进行的,而kaggle需要用到网站上给出的数据集并生成测试集的输出用于提交。这里选择keras搭建卷积网络进行识别,可以直接生成测试集的结果,最终结果识别率大概97%左右的样子。

# -*- coding: utf-8 -*-
"""
Created on Tue Jun 6 19:07:10 2017

@author: Administrator
"""

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten 
from keras.layers import Convolution2D, MaxPooling2D 
from keras.utils import np_utils
import os
import pandas as pd
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from keras import backend as K
import tensorflow as tf

# 全局变量 
batch_size = 100 
nb_classes = 10 
epochs = 20
# input image dimensions 
img_rows, img_cols = 28, 28 
# number of convolutional filters to use 
nb_filters = 32 
# size of pooling area for max pooling 
pool_size = (2, 2) 
# convolution kernel size 
kernel_size = (3, 3) 

inputfile='F:/data/kaggle/mnist/train.csv'
inputfile2= 'F:/data/kaggle/mnist/test.csv'
outputfile= 'F:/data/kaggle/mnist/test_label.csv'


pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
train= pd.read_csv(os.path.basename(inputfile)) #从训练数据文件读取数据
os.chdir(pwd)

pwd = os.getcwd()
os.chdir(os.path.dirname(inputfile)) 
test= pd.read_csv(os.path.basename(inputfile2)) #从测试数据文件读取数据
os.chdir(pwd)

x_train=train.iloc[:,1:785] #得到特征数据
y_train=train['label']
y_train = np_utils.to_categorical(y_train, 10)

mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #导入数据
x_test=mnist.test.images
y_test=mnist.test.labels
# 根据不同的backend定下不同的格式 
if K.image_dim_ordering() == 'th': 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 
 x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 
 input_shape = (1, img_rows, img_cols) 
 test = test.reshape(test.shape[0], 1, img_rows, img_cols) 
else: 
 x_train=np.array(x_train)
 test=np.array(test)
 x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
 X_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
 test = test.reshape(test.shape[0], img_rows, img_cols, 1) 
 input_shape = (img_rows, img_cols, 1) 

x_train = x_train.astype('float32') 
x_test = X_test.astype('float32') 
test = test.astype('float32') 
x_train /= 255 
X_test /= 255
test/=255 
print('X_train shape:', x_train.shape) 
print(x_train.shape[0], 'train samples') 
print(x_test.shape[0], 'test samples') 
print(test.shape[0], 'testOuput samples') 

model=Sequential()#model initial
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]), 
      padding='same', 
      input_shape=input_shape)) # 卷积层1 
model.add(Activation('relu')) #激活层 
model.add(Convolution2D(nb_filters, (kernel_size[0], kernel_size[1]))) #卷积层2 
model.add(Activation('relu')) #激活层 
model.add(MaxPooling2D(pool_size=pool_size)) #池化层 
model.add(Dropout(0.25)) #神经元随机失活 
model.add(Flatten()) #拉成一维数据 
model.add(Dense(128)) #全连接层1 
model.add(Activation('relu')) #激活层 
model.add(Dropout(0.5)) #随机失活 
model.add(Dense(nb_classes)) #全连接层2 
model.add(Activation('softmax')) #Softmax评分 

#编译模型 
model.compile(loss='categorical_crossentropy', 
    optimizer='adadelta', 
    metrics=['accuracy']) 
#训练模型 

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,verbose=1) 
model.predict(x_test)
#评估模型 
score = model.evaluate(x_test, y_test, verbose=0) 
print('Test score:', score[0]) 
print('Test accuracy:', score[1]) 

y_test=model.predict(test)

sess=tf.InteractiveSession()
y_test=sess.run(tf.arg_max(y_test,1))
y_test=pd.DataFrame(y_test)
y_test.to_csv(outputfile)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的ORM框架SQLObject入门实例
Apr 28 Python
django之常用命令详解
Jun 30 Python
PyQt5每天必学之布局管理
Apr 19 Python
对Python的zip函数妙用,旋转矩阵详解
Dec 13 Python
python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法
Jun 10 Python
python实现QQ批量登录功能
Jun 19 Python
Python单元测试与测试用例简析
Nov 09 Python
利用pandas将非数值数据转换成数值的方式
Dec 18 Python
Python通过Tesseract库实现文字识别
Mar 05 Python
Android Q之气泡弹窗的实现示例
Jun 23 Python
五种Python转义表示法
Nov 27 Python
M1芯片安装python3.9.1的实现
Feb 02 Python
解决tensorflow模型参数保存和加载的问题
Jul 26 #Python
解决tensorflow1.x版本加载saver.restore目录报错的问题
Jul 26 #Python
Flask web开发处理POST请求实现(登录案例)
Jul 26 #Python
基于tensorflow加载部分层的方法
Jul 26 #Python
利用python画出折线图
Jul 26 #Python
浅谈flask源码之请求过程
Jul 26 #Python
python画折线图的程序
Jul 26 #Python
You might like
用php和MySql来与ODBC数据连接
2006/10/09 PHP
php魔术方法功能与用法实例分析
2016/10/19 PHP
php注册和登录界面的实现案例(推荐)
2016/10/24 PHP
浅谈PHP命令执行php文件需要注意的问题
2016/12/16 PHP
PHP编程获取音频文件时长的方法【基于getid3类】
2017/04/20 PHP
利用php-cli和任务计划实现刷新token功能的方法
2017/05/03 PHP
php实现微信企业付款到个人零钱功能
2018/10/09 PHP
JQUERY复选框CHECKBOX全选,取消全选
2008/08/30 Javascript
jquery 得到当前页面高度和宽度的两个函数
2010/02/21 Javascript
bgsound 背景音乐 的一些常用方法及特殊用法小结
2010/05/11 Javascript
JavaScript和ActionScript的交互实现代码
2010/08/01 Javascript
javascript获取xml节点的最大值(实现代码)
2013/12/11 Javascript
JavaScript数组各种常见用法实例分析
2015/08/04 Javascript
jquery实现很酷的网页顶部图标下拉菜单效果
2015/08/22 Javascript
前端开发之CSS原理详解
2017/03/11 Javascript
深入理解 JavaScript 中的 JSON
2017/04/06 Javascript
JS实现搜索关键词的智能提示功能
2017/07/07 Javascript
node通过express搭建自己的服务器
2017/09/30 Javascript
Vue中 v-if 和v-else-if页面加载出现闪现的问题及解决方法
2018/10/12 Javascript
vue v-for直接循环数字实例
2019/11/07 Javascript
10个Python小技巧你值得拥有
2018/09/29 Python
Python Flask框架模板操作实例分析
2019/05/03 Python
浅谈python中统计计数的几种方法和Counter详解
2019/11/07 Python
Python之关于类变量的两种赋值区别详解
2020/03/12 Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
2020/05/26 Python
matplotlib常见函数之plt.rcParams、matshow的使用(坐标轴设置)
2021/01/05 Python
解决tensorflow模型压缩的问题_踩坑无数,总算搞定
2021/03/02 Python
html5写一个BUI折叠菜单插件的实现方法
2019/09/11 HTML / CSS
Scotch Porter官方网站:男士美容产品
2020/08/31 全球购物
建筑专业自荐信
2013/10/18 职场文书
新书发布会策划方案
2014/06/09 职场文书
交通事故一次性赔偿协议书范本
2014/11/02 职场文书
慰问信格式
2015/02/14 职场文书
结婚主持人致辞
2015/07/28 职场文书
《将心比心》教学反思
2016/02/23 职场文书
将MySQL的表数据全量导入clichhouse库中
2022/03/21 MySQL