Implementing MNIST Handwritten Digit Recognition with PyTorch


Posted in Python on January 18, 2020

Experimental environment

Windows 10 + Anaconda + Jupyter Notebook

PyTorch 1.1.0

Python 3.7

GPU (optional)

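About the MNIST dataset

MNIST contains 60,000 training images and 10,000 test images, each a 28x28 grayscale handwritten digit. It is often called the "Hello World" of computer vision. The CNN used in this article reaches about 99% accuracy on the MNIST test set. Let's get started.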

Importing packages

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
print(torch.__version__)  # check the installed PyTorch version

Defining hyperparameters

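BATCH_SIZE = 512  # number of samples per batch
EPOCHS = 20       # number of full passes over the training set
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available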

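Dataset

We use the MNIST dataset that ships with torchvision and read the training and test data with two separate DataLoaders. With download=True, torchvision only downloads the files if they are not already in the data directory, so the flag can stay True; setting it to False after the first download also works.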

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),                      # uint8 [0, 255] -> float [0.0, 1.0]
                       transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and standard deviation
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)   # reshuffle the training set every epoch

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=False)  # evaluation does not need shuffling
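ToTensor() scales each 8-bit pixel from [0, 255] to [0.0, 1.0], and Normalize((0.1307,), (0.3081,)) then computes (x - mean) / std for the single grayscale channel. The two constants are the commonly quoted mean and standard deviation of the MNIST training pixels after that scaling. The sketch below reproduces the arithmetic in plain Python for one pixel; the helper `normalize_pixel` is hypothetical and exists only for illustration.

```python
MEAN = 0.1307  # mean of MNIST training pixels after scaling to [0, 1]
STD = 0.3081   # standard deviation of the same pixels


def normalize_pixel(raw: int, mean: float = MEAN, std: float = STD) -> float:
    """Mimic ToTensor() followed by Normalize() for one 8-bit grayscale pixel."""
    if not 0 <= raw <= 255:
        raise ValueError("expected an 8-bit pixel value in [0, 255]")
    scaled = raw / 255.0          # ToTensor(): [0, 255] -> [0.0, 1.0]
    return (scaled - mean) / std  # Normalize(): shift by the mean, divide by the std


# A black background pixel and a fully white stroke pixel
print(round(normalize_pixel(0), 4))    # -0.4242
print(round(normalize_pixel(255), 4))  # 2.8215
```

So the network sees inputs roughly in [-0.42, 2.82], centered near zero, instead of raw values in [0, 255].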

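Defining the network

The network consists of two convolutional layers and two fully connected (linear) layers. The last layer outputs 10 values, one for each digit 0-9, which log_softmax turns into log-probabilities.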

class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(1, 10, 5)   # input: (1,28,28)  output: (10,24,24)
    self.conv2 = nn.Conv2d(10, 20, 3)  # input: (10,12,12) output: (20,10,10)
    self.fc1 = nn.Linear(20*10*10, 500)
    self.fc2 = nn.Linear(500, 10)
  def forward(self, x):
    in_size = x.size(0)                # batch size
    out = self.conv1(x)
    out = F.relu(out)
    out = F.max_pool2d(out, 2, 2)      # (10,24,24) -> (10,12,12)
    out = self.conv2(out)
    out = F.relu(out)
    out = out.view(in_size, -1)        # flatten to (batch, 2000)
    out = self.fc1(out)
    out = F.relu(out)
    out = self.fc2(out)
    out = F.log_softmax(out, dim=1)    # log-probabilities over the 10 classes
    return out
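The shape comments in ConvNet can be checked by hand. With no padding and dilation 1, both Conv2d and max pooling produce floor((n + 2p - k) / s) + 1 pixels per side. The sketch below applies that formula layer by layer and also counts parameters (weights plus biases). The helper names `conv_out`, `conv_params` and `linear_params` are hypothetical, written only for this illustration.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of Conv2d / MaxPool2d (dilation 1): floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_params(c_in: int, c_out: int, kernel: int) -> int:
    """Conv2d weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * kernel * kernel + c_out


def linear_params(n_in: int, n_out: int) -> int:
    """Linear weight matrix (n_out * n_in) plus one bias per output."""
    return n_out * n_in + n_out


side = 28
side = conv_out(side, kernel=5)            # conv1: 28 -> 24
side = conv_out(side, kernel=2, stride=2)  # max_pool2d(2, 2): 24 -> 12
side = conv_out(side, kernel=3)            # conv2: 12 -> 10 (no pooling afterwards)
flat = 20 * side * side                    # features fed into fc1

total = (conv_params(1, 10, 5) + conv_params(10, 20, 3)
         + linear_params(flat, 500) + linear_params(500, 10))
print(side, flat, total)  # 10 2000 1007590
```

fc1 alone holds 1,000,500 of the roughly 1.0M parameters. On the real model, `sum(p.numel() for p in model.parameters())` should report the same total.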

Instantiating the network

model = ConvNet().to(DEVICE)  # move the network to DEVICE (the GPU if one is available)
optimizer = optim.Adam(model.parameters())  # Adam optimizer with its default settings
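Because no learning rate is passed, Adam runs with its defaults. The sketch below, using a small stand-in nn.Linear (an assumption for illustration, not the article's model), reads those defaults back from `param_groups` and shows the explicit form. The PyTorch documentation also recommends moving a model to its device before constructing the optimizer, which is the order used above.

```python
import torch.nn as nn
import torch.optim as optim

# A stand-in module: any nn.Module's parameters behave the same way here.
layer = nn.Linear(4, 2)

default_opt = optim.Adam(layer.parameters())
print(default_opt.param_groups[0]["lr"])     # 0.001
print(default_opt.param_groups[0]["betas"])  # (0.9, 0.999)

# Passing the learning rate explicitly makes the choice visible and easy to tune.
explicit_opt = optim.Adam(layer.parameters(), lr=1e-3)
```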

Defining the training function

def train(model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if (batch_idx + 1) % 30 == 0:  # print progress every 30 batches
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
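Since forward() already ends with F.log_softmax, train() uses F.nll_loss. That pair computes the same value as F.cross_entropy applied to the raw scores of fc2. The sketch below checks this on a random fake batch (random tensors, not MNIST data). If you switched to F.cross_entropy, forward() would return the fc2 output without log_softmax.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)          # fake batch: 8 samples, 10 digit classes
target = torch.randint(0, 10, (8,))  # fake integer labels in [0, 9]

# The article's approach: log_softmax in forward(), nll_loss in train()
loss_two_step = F.nll_loss(F.log_softmax(logits, dim=1), target)

# The fused equivalent applied directly to the raw scores
loss_fused = F.cross_entropy(logits, target)

print(torch.allclose(loss_two_step, loss_fused))  # True
```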

Defining the test function

def test(model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum the losses over the batch
      pred = output.max(1, keepdim=True)[1]  # index of the highest log-probability, i.e. the predicted digit
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
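The prediction step takes the index of the largest log-probability per row and compares it with the label. The sketch below runs the same steps on three hand-written rows (hypothetical values, not real model output) and shows that `argmax` gives the same indices as `max(...)[1]`. It also shows that the `{:.0f}` format in the report rounds: 9859/10000 from epoch 3 prints as 99% although it is 98.59%.

```python
import torch

# Hypothetical log-probabilities for 3 samples over 3 classes (not MNIST output)
output = torch.log_softmax(torch.tensor([[2.0, 0.5, 0.1],
                                         [0.2, 0.3, 3.0],
                                         [1.0, 4.0, 0.0]]), dim=1)
target = torch.tensor([0, 2, 0])

pred_max = output.max(1, keepdim=True)[1]         # the form used in test()
pred_argmax = output.argmax(dim=1, keepdim=True)  # equivalent, more explicit
correct = pred_argmax.eq(target.view_as(pred_argmax)).sum().item()
print(correct)  # 2 (the third sample is predicted as class 1)

# The report uses {:.0f}, which rounds to a whole percent
print('{:.0f}%'.format(100. * 9859 / 10000))  # 99%
print('{:.2f}%'.format(100. * 9859 / 10000))  # 98.59%
```

Using `{:.2f}` in test() would make small differences between epochs visible.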

Training

for epoch in range(1, EPOCHS + 1):
  train(model, DEVICE, train_loader, optimizer, epoch)
  test(model, DEVICE, test_loader)
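The progress lines printed by train() follow from the batch size. With BATCH_SIZE = 512, one epoch has ceil(60000 / 512) = 118 batches (the DataLoader keeps the smaller last batch by default), and a line is printed after batches 30, 60 and 90. The printed count is `batch_idx * len(data)`, i.e. the samples processed before the current batch. A plain-Python sketch of that arithmetic:

```python
import math

TRAIN_SIZE = 60000  # MNIST training images
BATCH_SIZE = 512    # same value as in the hyperparameters
LOG_EVERY = 30      # train() prints when (batch_idx + 1) % 30 == 0

num_batches = math.ceil(TRAIN_SIZE / BATCH_SIZE)           # len(train_loader)
last_batch = TRAIN_SIZE - (num_batches - 1) * BATCH_SIZE   # size of the final, partial batch

logged = []
for batch_idx in range(num_batches):
    if (batch_idx + 1) % LOG_EVERY == 0:
        seen = batch_idx * BATCH_SIZE  # the sample count train() prints
        pct = 100. * batch_idx / num_batches
        logged.append(seen)
        print('[{}/{} ({:.0f}%)]'.format(seen, TRAIN_SIZE, pct))

print(num_batches, last_batch)  # 118 96
```

This prints [14848/60000 (25%)], [30208/60000 (50%)] and [45568/60000 (75%)], matching the log.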

Results

Train Epoch: 1 [14848/60000 (25%)]	Loss: 0.375058
Train Epoch: 1 [30208/60000 (50%)]	Loss: 0.255248
Train Epoch: 1 [45568/60000 (75%)]	Loss: 0.128060

Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%)

Train Epoch: 2 [14848/60000 (25%)]	Loss: 0.093066
Train Epoch: 2 [30208/60000 (50%)]	Loss: 0.087888
Train Epoch: 2 [45568/60000 (75%)]	Loss: 0.068078

Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%)

Train Epoch: 3 [14848/60000 (25%)]	Loss: 0.043926
Train Epoch: 3 [30208/60000 (50%)]	Loss: 0.037321
Train Epoch: 3 [45568/60000 (75%)]	Loss: 0.068404

Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%)

Train Epoch: 4 [14848/60000 (25%)]	Loss: 0.031654
Train Epoch: 4 [30208/60000 (50%)]	Loss: 0.041341
Train Epoch: 4 [45568/60000 (75%)]	Loss: 0.036493

Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%)

Train Epoch: 5 [14848/60000 (25%)]	Loss: 0.027688
Train Epoch: 5 [30208/60000 (50%)]	Loss: 0.019488
Train Epoch: 5 [45568/60000 (75%)]	Loss: 0.018023

Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%)

Train Epoch: 6 [14848/60000 (25%)]	Loss: 0.024212
Train Epoch: 6 [30208/60000 (50%)]	Loss: 0.018689
Train Epoch: 6 [45568/60000 (75%)]	Loss: 0.040412

Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%)

Train Epoch: 7 [14848/60000 (25%)]	Loss: 0.030426
Train Epoch: 7 [30208/60000 (50%)]	Loss: 0.026939
Train Epoch: 7 [45568/60000 (75%)]	Loss: 0.010722

Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%)

Train Epoch: 8 [14848/60000 (25%)]	Loss: 0.021109
Train Epoch: 8 [30208/60000 (50%)]	Loss: 0.034845
Train Epoch: 8 [45568/60000 (75%)]	Loss: 0.011223

Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%)

Train Epoch: 9 [14848/60000 (25%)]	Loss: 0.011391
Train Epoch: 9 [30208/60000 (50%)]	Loss: 0.008091
Train Epoch: 9 [45568/60000 (75%)]	Loss: 0.039870

Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%)

Train Epoch: 10 [14848/60000 (25%)]	Loss: 0.026813
Train Epoch: 10 [30208/60000 (50%)]	Loss: 0.011159
Train Epoch: 10 [45568/60000 (75%)]	Loss: 0.024884

Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%)

Train Epoch: 11 [14848/60000 (25%)]	Loss: 0.006420
Train Epoch: 11 [30208/60000 (50%)]	Loss: 0.003641
Train Epoch: 11 [45568/60000 (75%)]	Loss: 0.003402

Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%)

Train Epoch: 12 [14848/60000 (25%)]	Loss: 0.006866
Train Epoch: 12 [30208/60000 (50%)]	Loss: 0.012617
Train Epoch: 12 [45568/60000 (75%)]	Loss: 0.008548

Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%)

Train Epoch: 13 [14848/60000 (25%)]	Loss: 0.010539
Train Epoch: 13 [30208/60000 (50%)]	Loss: 0.002952
Train Epoch: 13 [45568/60000 (75%)]	Loss: 0.002313

Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%)

Train Epoch: 14 [14848/60000 (25%)]	Loss: 0.002100
Train Epoch: 14 [30208/60000 (50%)]	Loss: 0.000779
Train Epoch: 14 [45568/60000 (75%)]	Loss: 0.005952

Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%)

Train Epoch: 15 [14848/60000 (25%)]	Loss: 0.006053
Train Epoch: 15 [30208/60000 (50%)]	Loss: 0.002559
Train Epoch: 15 [45568/60000 (75%)]	Loss: 0.002555

Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%)

Train Epoch: 16 [14848/60000 (25%)]	Loss: 0.000895
Train Epoch: 16 [30208/60000 (50%)]	Loss: 0.004923
Train Epoch: 16 [45568/60000 (75%)]	Loss: 0.002339

Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%)

Train Epoch: 17 [14848/60000 (25%)]	Loss: 0.004136
Train Epoch: 17 [30208/60000 (50%)]	Loss: 0.000927
Train Epoch: 17 [45568/60000 (75%)]	Loss: 0.002084

Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%)

Train Epoch: 18 [14848/60000 (25%)]	Loss: 0.004508
Train Epoch: 18 [30208/60000 (50%)]	Loss: 0.001272
Train Epoch: 18 [45568/60000 (75%)]	Loss: 0.000543

Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%)

Train Epoch: 19 [14848/60000 (25%)]	Loss: 0.001699
Train Epoch: 19 [30208/60000 (50%)]	Loss: 0.000661
Train Epoch: 19 [45568/60000 (75%)]	Loss: 0.000275

Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%)

Train Epoch: 20 [14848/60000 (25%)]	Loss: 0.000441
Train Epoch: 20 [30208/60000 (50%)]	Loss: 0.000695
Train Epoch: 20 [45568/60000 (75%)]	Loss: 0.000467

Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)
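The log suggests mild overfitting. The training loss keeps falling to below 0.001 by epoch 20, while the test loss is lowest around epoch 10 and then drifts upward. The sketch below copies the test-set numbers from the log above into a dict and picks the best epochs in plain Python.

```python
# Test-set results copied from the log above: epoch -> (average loss, correct out of 10000)
results = {
    1: (0.0992, 9690), 2: (0.0599, 9816), 3: (0.0416, 9859), 4: (0.0361, 9873),
    5: (0.0344, 9875), 6: (0.0350, 9879), 7: (0.0287, 9892), 8: (0.0299, 9904),
    9: (0.0341, 9890), 10: (0.0286, 9901), 11: (0.0377, 9894), 12: (0.0311, 9908),
    13: (0.0293, 9905), 14: (0.0335, 9897), 15: (0.0357, 9894), 16: (0.0400, 9893),
    17: (0.0353, 9895), 18: (0.0380, 9894), 19: (0.0339, 9905), 20: (0.0396, 9894),
}

best_loss_epoch = min(results, key=lambda e: results[e][0])  # lowest test loss
best_acc_epoch = max(results, key=lambda e: results[e][1])   # most correct predictions
final_loss, final_correct = results[max(results)]            # last epoch

print(best_loss_epoch, results[best_loss_epoch])  # 10 (0.0286, 9901)
print(best_acc_epoch, results[best_acc_epoch])    # 12 (0.0311, 9908)
print(final_loss, final_correct)                  # 0.0396 9894
```

The original loop only keeps the weights from the last epoch. One possible change, not part of the original code, is to have test() return the accuracy and save `model.state_dict()` with `torch.save` whenever it improves.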

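Summary

The workflow of a real project: find a dataset, preprocess the data, define the model, choose the hyperparameters, train and test, then use the results to adjust the hyperparameters or the model.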

That is all the code for MNIST handwritten digit recognition with PyTorch that I wanted to share. I hope it serves as a useful reference.
