使用PyTorch实现MNIST手写体识别代码


Posted in Python onJanuary 18, 2020

实验环境

win10 + anaconda + jupyter notebook

Pytorch1.1.0

Python3.7

gpu环境(可选)

MNIST数据集介绍

MNIST 包括6万张28x28的训练样本,1万张测试样本,可以说是CV里的“Hello Word”。本文使用的CNN网络将MNIST数据的识别率提高到了99%。下面我们就开始进行实战。

导入包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.__version__

定义超参数

BATCH_SIZE=512
EPOCHS=20 
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

数据集

我们直接使用PyTorch中自带的dataset,并使用DataLoader对训练数据和测试数据分别进行读取。如果下载过数据集这里download可选择False

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, 
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=BATCH_SIZE, shuffle=True)

定义网络

该网络包括两个卷积层和两个线性层,最后输出10个维度,即代表0-9十个数字。

class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Conv2d(1,10,5) # input:(1,28,28) output:(10,24,24) 
    self.conv2=nn.Conv2d(10,20,3) # input:(10,12,12) output:(20,10,10)
    self.fc1 = nn.Linear(20*10*10,500)
    self.fc2 = nn.Linear(500,10)
  def forward(self,x):
    in_size = x.size(0)
    out = self.conv1(x)
    out = F.relu(out)
    out = F.max_pool2d(out, 2, 2) 
    out = self.conv2(out)
    out = F.relu(out)
    out = out.view(in_size,-1)
    out = self.fc1(out)
    out = F.relu(out)
    out = self.fc2(out)
    out = F.log_softmax(out,dim=1)
    return out

实例化网络

model = ConvNet().to(DEVICE) # 将网络移动到gpu上
optimizer = optim.Adam(model.parameters()) # 使用Adam优化器

定义训练函数

def train(model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if(batch_idx+1)%30 == 0: 
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

定义测试函数

def test(model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
      pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

开始训练

for epoch in range(1, EPOCHS + 1):
  train(model, DEVICE, train_loader, optimizer, epoch)
  test(model, DEVICE, test_loader)

实验结果

Train Epoch: 1 [14848/60000 (25%)]	Loss: 0.375058
Train Epoch: 1 [30208/60000 (50%)]	Loss: 0.255248
Train Epoch: 1 [45568/60000 (75%)]	Loss: 0.128060

Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%)

Train Epoch: 2 [14848/60000 (25%)]	Loss: 0.093066
Train Epoch: 2 [30208/60000 (50%)]	Loss: 0.087888
Train Epoch: 2 [45568/60000 (75%)]	Loss: 0.068078

Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%)

Train Epoch: 3 [14848/60000 (25%)]	Loss: 0.043926
Train Epoch: 3 [30208/60000 (50%)]	Loss: 0.037321
Train Epoch: 3 [45568/60000 (75%)]	Loss: 0.068404

Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%)

Train Epoch: 4 [14848/60000 (25%)]	Loss: 0.031654
Train Epoch: 4 [30208/60000 (50%)]	Loss: 0.041341
Train Epoch: 4 [45568/60000 (75%)]	Loss: 0.036493

Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%)

Train Epoch: 5 [14848/60000 (25%)]	Loss: 0.027688
Train Epoch: 5 [30208/60000 (50%)]	Loss: 0.019488
Train Epoch: 5 [45568/60000 (75%)]	Loss: 0.018023

Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%)

Train Epoch: 6 [14848/60000 (25%)]	Loss: 0.024212
Train Epoch: 6 [30208/60000 (50%)]	Loss: 0.018689
Train Epoch: 6 [45568/60000 (75%)]	Loss: 0.040412

Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%)

Train Epoch: 7 [14848/60000 (25%)]	Loss: 0.030426
Train Epoch: 7 [30208/60000 (50%)]	Loss: 0.026939
Train Epoch: 7 [45568/60000 (75%)]	Loss: 0.010722

Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%)

Train Epoch: 8 [14848/60000 (25%)]	Loss: 0.021109
Train Epoch: 8 [30208/60000 (50%)]	Loss: 0.034845
Train Epoch: 8 [45568/60000 (75%)]	Loss: 0.011223

Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%)

Train Epoch: 9 [14848/60000 (25%)]	Loss: 0.011391
Train Epoch: 9 [30208/60000 (50%)]	Loss: 0.008091
Train Epoch: 9 [45568/60000 (75%)]	Loss: 0.039870

Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%)

Train Epoch: 10 [14848/60000 (25%)]	Loss: 0.026813
Train Epoch: 10 [30208/60000 (50%)]	Loss: 0.011159
Train Epoch: 10 [45568/60000 (75%)]	Loss: 0.024884

Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%)

Train Epoch: 11 [14848/60000 (25%)]	Loss: 0.006420
Train Epoch: 11 [30208/60000 (50%)]	Loss: 0.003641
Train Epoch: 11 [45568/60000 (75%)]	Loss: 0.003402

Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%)

Train Epoch: 12 [14848/60000 (25%)]	Loss: 0.006866
Train Epoch: 12 [30208/60000 (50%)]	Loss: 0.012617
Train Epoch: 12 [45568/60000 (75%)]	Loss: 0.008548

Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%)

Train Epoch: 13 [14848/60000 (25%)]	Loss: 0.010539
Train Epoch: 13 [30208/60000 (50%)]	Loss: 0.002952
Train Epoch: 13 [45568/60000 (75%)]	Loss: 0.002313

Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%)

Train Epoch: 14 [14848/60000 (25%)]	Loss: 0.002100
Train Epoch: 14 [30208/60000 (50%)]	Loss: 0.000779
Train Epoch: 14 [45568/60000 (75%)]	Loss: 0.005952

Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%)

Train Epoch: 15 [14848/60000 (25%)]	Loss: 0.006053
Train Epoch: 15 [30208/60000 (50%)]	Loss: 0.002559
Train Epoch: 15 [45568/60000 (75%)]	Loss: 0.002555

Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%)

Train Epoch: 16 [14848/60000 (25%)]	Loss: 0.000895
Train Epoch: 16 [30208/60000 (50%)]	Loss: 0.004923
Train Epoch: 16 [45568/60000 (75%)]	Loss: 0.002339

Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%)

Train Epoch: 17 [14848/60000 (25%)]	Loss: 0.004136
Train Epoch: 17 [30208/60000 (50%)]	Loss: 0.000927
Train Epoch: 17 [45568/60000 (75%)]	Loss: 0.002084

Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%)

Train Epoch: 18 [14848/60000 (25%)]	Loss: 0.004508
Train Epoch: 18 [30208/60000 (50%)]	Loss: 0.001272
Train Epoch: 18 [45568/60000 (75%)]	Loss: 0.000543

Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%)

Train Epoch: 19 [14848/60000 (25%)]	Loss: 0.001699
Train Epoch: 19 [30208/60000 (50%)]	Loss: 0.000661
Train Epoch: 19 [45568/60000 (75%)]	Loss: 0.000275

Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%)

Train Epoch: 20 [14848/60000 (25%)]	Loss: 0.000441
Train Epoch: 20 [30208/60000 (50%)]	Loss: 0.000695
Train Epoch: 20 [45568/60000 (75%)]	Loss: 0.000467

Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)

总结

一个实际项目的工作流程:找到数据集,对数据做预处理,定义我们的模型,调整超参数,测试训练,再通过训练结果对超参数进行调整或者对模型进行调整。

以上这篇使用PyTorch实现MNIST手写体识别代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现从字符串中找出字符1的位置以及个数的方法
Aug 25 Python
开源Web应用框架Django图文教程
Mar 09 Python
Python2实现的LED大数字显示效果示例
Sep 04 Python
python3实现公众号每日定时发送日报和图片
Feb 24 Python
python实战之实现excel读取、统计、写入的示例讲解
May 02 Python
Python3实现爬虫爬取赶集网列表功能【基于request和BeautifulSoup模块】
Dec 05 Python
python中时间模块的基本使用教程
May 14 Python
python开发之anaconda以及win7下安装gensim的方法
Jul 05 Python
使用python telnetlib批量备份交换机配置的方法
Jul 25 Python
Python3 字典dictionary入门基础附实例
Feb 10 Python
python闭包、深浅拷贝、垃圾回收、with语句知识点汇总
Mar 11 Python
Pytorch之Tensor和Numpy之间的转换的实现方法
Sep 03 Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 #Python
You might like
php面向对象全攻略 (十二) 抽象方法和抽象类
2009/09/30 PHP
php的$_FILES的临时储存文件与回收机制实测过程
2013/07/12 PHP
php数组使用规则分析
2015/02/27 PHP
PHP生成唯一订单号
2015/07/05 PHP
php高性能日志系统 seaslog 的安装与使用方法分析
2020/02/29 PHP
兼容FireFox 的 js 日历 支持时间的获取
2009/03/04 Javascript
自己的js工具_Form 封装
2009/08/21 Javascript
jQuery 实现侧边浮动导航菜单效果
2014/12/26 Javascript
JavaScript制作简单分页插件
2016/09/11 Javascript
jQuery插件FusionCharts实现的Marimekko图效果示例【附demo源码】
2017/03/24 jQuery
JavaScript之排序函数_动力节点Java学院整理
2017/06/30 Javascript
在小程序中使用canvas的方法示例
2018/09/17 Javascript
基于vue2的canvas时钟倒计时组件步骤解析
2018/11/05 Javascript
详解nodejs解压版安装和配置(带有搭建前端项目脚手架)
2018/12/06 NodeJs
JavaScript数据结构与算法之基本排序算法定义与效率比较【冒泡、选择、插入排序】
2019/02/21 Javascript
layui 弹出层回调获取弹出层数据的例子
2019/09/02 Javascript
node.js文件操作系统实例详解
2019/11/05 Javascript
[01:07:11]Secret vs Newbee 2019国际邀请赛小组赛 BO2 第二场 8.15
2019/08/17 DOTA
浅谈Python单向链表的实现
2015/12/24 Python
python web框架学习笔记
2016/05/03 Python
python中defaultdict的用法详解
2017/06/07 Python
python多进程控制学习小结
2018/10/31 Python
Python3 SSH远程连接服务器的方法示例
2018/12/29 Python
python 切换root 执行命令的方法
2019/01/19 Python
pytorch  网络参数 weight bias 初始化详解
2020/06/24 Python
日本最大美瞳直送网:Morecontact(中文)
2019/04/03 全球购物
祖国在我心中演讲稿400字
2014/05/04 职场文书
装修施工安全责任书
2014/07/24 职场文书
优秀团员事迹材料2000字
2014/08/20 职场文书
优秀党员学习焦裕禄精神思想汇报范文
2014/09/10 职场文书
关于成绩下滑的自我检讨书
2014/09/20 职场文书
职工擅自离岗检讨书
2014/09/23 职场文书
校园安全广播稿范文
2014/09/25 职场文书
群众路线批评与自我批评发言稿
2014/10/16 职场文书
学习心理学的体会
2014/11/07 职场文书
使用PDF.js渲染canvas实现预览pdf的效果示例
2021/04/17 Javascript