使用PyTorch实现MNIST手写体识别代码


Posted in Python onJanuary 18, 2020

实验环境

win10 + anaconda + jupyter notebook

Pytorch1.1.0

Python3.7

gpu环境(可选)

MNIST数据集介绍

MNIST 包括6万张28x28的训练样本,1万张测试样本,可以说是CV里的“Hello Word”。本文使用的CNN网络将MNIST数据的识别率提高到了99%。下面我们就开始进行实战。

导入包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.__version__

定义超参数

BATCH_SIZE=512
EPOCHS=20 
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

数据集

我们直接使用PyTorch中自带的dataset,并使用DataLoader对训练数据和测试数据分别进行读取。如果下载过数据集这里download可选择False

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, 
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=BATCH_SIZE, shuffle=True)

定义网络

该网络包括两个卷积层和两个线性层,最后输出10个维度,即代表0-9十个数字。

class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Conv2d(1,10,5) # input:(1,28,28) output:(10,24,24) 
    self.conv2=nn.Conv2d(10,20,3) # input:(10,12,12) output:(20,10,10)
    self.fc1 = nn.Linear(20*10*10,500)
    self.fc2 = nn.Linear(500,10)
  def forward(self,x):
    in_size = x.size(0)
    out = self.conv1(x)
    out = F.relu(out)
    out = F.max_pool2d(out, 2, 2) 
    out = self.conv2(out)
    out = F.relu(out)
    out = out.view(in_size,-1)
    out = self.fc1(out)
    out = F.relu(out)
    out = self.fc2(out)
    out = F.log_softmax(out,dim=1)
    return out

实例化网络

model = ConvNet().to(DEVICE) # 将网络移动到gpu上
optimizer = optim.Adam(model.parameters()) # 使用Adam优化器

定义训练函数

def train(model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if(batch_idx+1)%30 == 0: 
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

定义测试函数

def test(model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
      pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

开始训练

for epoch in range(1, EPOCHS + 1):
  train(model, DEVICE, train_loader, optimizer, epoch)
  test(model, DEVICE, test_loader)

实验结果

Train Epoch: 1 [14848/60000 (25%)]	Loss: 0.375058
Train Epoch: 1 [30208/60000 (50%)]	Loss: 0.255248
Train Epoch: 1 [45568/60000 (75%)]	Loss: 0.128060

Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%)

Train Epoch: 2 [14848/60000 (25%)]	Loss: 0.093066
Train Epoch: 2 [30208/60000 (50%)]	Loss: 0.087888
Train Epoch: 2 [45568/60000 (75%)]	Loss: 0.068078

Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%)

Train Epoch: 3 [14848/60000 (25%)]	Loss: 0.043926
Train Epoch: 3 [30208/60000 (50%)]	Loss: 0.037321
Train Epoch: 3 [45568/60000 (75%)]	Loss: 0.068404

Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%)

Train Epoch: 4 [14848/60000 (25%)]	Loss: 0.031654
Train Epoch: 4 [30208/60000 (50%)]	Loss: 0.041341
Train Epoch: 4 [45568/60000 (75%)]	Loss: 0.036493

Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%)

Train Epoch: 5 [14848/60000 (25%)]	Loss: 0.027688
Train Epoch: 5 [30208/60000 (50%)]	Loss: 0.019488
Train Epoch: 5 [45568/60000 (75%)]	Loss: 0.018023

Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%)

Train Epoch: 6 [14848/60000 (25%)]	Loss: 0.024212
Train Epoch: 6 [30208/60000 (50%)]	Loss: 0.018689
Train Epoch: 6 [45568/60000 (75%)]	Loss: 0.040412

Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%)

Train Epoch: 7 [14848/60000 (25%)]	Loss: 0.030426
Train Epoch: 7 [30208/60000 (50%)]	Loss: 0.026939
Train Epoch: 7 [45568/60000 (75%)]	Loss: 0.010722

Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%)

Train Epoch: 8 [14848/60000 (25%)]	Loss: 0.021109
Train Epoch: 8 [30208/60000 (50%)]	Loss: 0.034845
Train Epoch: 8 [45568/60000 (75%)]	Loss: 0.011223

Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%)

Train Epoch: 9 [14848/60000 (25%)]	Loss: 0.011391
Train Epoch: 9 [30208/60000 (50%)]	Loss: 0.008091
Train Epoch: 9 [45568/60000 (75%)]	Loss: 0.039870

Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%)

Train Epoch: 10 [14848/60000 (25%)]	Loss: 0.026813
Train Epoch: 10 [30208/60000 (50%)]	Loss: 0.011159
Train Epoch: 10 [45568/60000 (75%)]	Loss: 0.024884

Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%)

Train Epoch: 11 [14848/60000 (25%)]	Loss: 0.006420
Train Epoch: 11 [30208/60000 (50%)]	Loss: 0.003641
Train Epoch: 11 [45568/60000 (75%)]	Loss: 0.003402

Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%)

Train Epoch: 12 [14848/60000 (25%)]	Loss: 0.006866
Train Epoch: 12 [30208/60000 (50%)]	Loss: 0.012617
Train Epoch: 12 [45568/60000 (75%)]	Loss: 0.008548

Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%)

Train Epoch: 13 [14848/60000 (25%)]	Loss: 0.010539
Train Epoch: 13 [30208/60000 (50%)]	Loss: 0.002952
Train Epoch: 13 [45568/60000 (75%)]	Loss: 0.002313

Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%)

Train Epoch: 14 [14848/60000 (25%)]	Loss: 0.002100
Train Epoch: 14 [30208/60000 (50%)]	Loss: 0.000779
Train Epoch: 14 [45568/60000 (75%)]	Loss: 0.005952

Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%)

Train Epoch: 15 [14848/60000 (25%)]	Loss: 0.006053
Train Epoch: 15 [30208/60000 (50%)]	Loss: 0.002559
Train Epoch: 15 [45568/60000 (75%)]	Loss: 0.002555

Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%)

Train Epoch: 16 [14848/60000 (25%)]	Loss: 0.000895
Train Epoch: 16 [30208/60000 (50%)]	Loss: 0.004923
Train Epoch: 16 [45568/60000 (75%)]	Loss: 0.002339

Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%)

Train Epoch: 17 [14848/60000 (25%)]	Loss: 0.004136
Train Epoch: 17 [30208/60000 (50%)]	Loss: 0.000927
Train Epoch: 17 [45568/60000 (75%)]	Loss: 0.002084

Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%)

Train Epoch: 18 [14848/60000 (25%)]	Loss: 0.004508
Train Epoch: 18 [30208/60000 (50%)]	Loss: 0.001272
Train Epoch: 18 [45568/60000 (75%)]	Loss: 0.000543

Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%)

Train Epoch: 19 [14848/60000 (25%)]	Loss: 0.001699
Train Epoch: 19 [30208/60000 (50%)]	Loss: 0.000661
Train Epoch: 19 [45568/60000 (75%)]	Loss: 0.000275

Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%)

Train Epoch: 20 [14848/60000 (25%)]	Loss: 0.000441
Train Epoch: 20 [30208/60000 (50%)]	Loss: 0.000695
Train Epoch: 20 [45568/60000 (75%)]	Loss: 0.000467

Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)

总结

一个实际项目的工作流程:找到数据集,对数据做预处理,定义我们的模型,调整超参数,测试训练,再通过训练结果对超参数进行调整或者对模型进行调整。

以上这篇使用PyTorch实现MNIST手写体识别代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python help()函数用法详解
Mar 11 Python
Python 自动补全(vim)
Nov 30 Python
Python list操作用法总结
Nov 10 Python
python实现windows下文件备份脚本
May 27 Python
python实现比较文件内容异同
Jun 22 Python
Python的Tkinter点击按钮触发事件的例子
Jul 19 Python
简单了解Django应用app及分布式路由
Jul 24 Python
详解python中的数据类型和控制流
Aug 08 Python
python如何调用百度识图api
Sep 29 Python
pytorch 移动端部署之helloworld的使用
Oct 30 Python
基于 Python 实践感知器分类算法
Jan 07 Python
解决pycharm下载库时出现Failed to install package的问题
Sep 04 Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
Jan 18 #Python
You might like
php中隐形字符65279(utf-8的BOM头)问题
2014/08/16 PHP
php查询mysql大量数据造成内存不足的解决方法
2015/03/04 PHP
php生成图片验证码的实例讲解
2015/08/03 PHP
Yii2中使用asset压缩js,css文件的方法
2016/11/24 PHP
PHP实现随机数字、字母的验证码功能
2018/08/01 PHP
JS中showModalDialog 的使用解析
2013/04/17 Javascript
js处理自己不能定义二维数组的方法详解
2014/03/03 Javascript
原生js和jQuery随意改变div属性style的名称和值
2014/10/22 Javascript
avalonjs制作响应式瀑布流特效
2015/05/06 Javascript
js实现跨域的多种方法
2015/12/25 Javascript
基于jQuery实现弹出可关闭遮罩提示框实例代码
2016/07/18 Javascript
基于JavaScript实现点击页面任何位置返回
2016/08/31 Javascript
ES6新特征数字、数组、字符串
2016/10/01 Javascript
JavaScript计算值然后把值嵌入到html中的实现方法
2016/10/29 Javascript
JAVA Web实时消息后台服务器推送技术---GoEasy
2016/11/04 Javascript
原生js实现秒表计时器功能
2017/02/16 Javascript
js中的DOM模拟购物车功能
2017/03/22 Javascript
微信小程序 页面跳转如何实现传值
2017/04/05 Javascript
jQuery实现的表格前端排序功能示例
2017/09/18 jQuery
利用canvas中toDataURL()将图片转为dataURL(base64)的方法详解
2017/11/20 Javascript
node vue项目开发之前后端分离实战记录
2017/12/13 Javascript
详解小程序rich-text对富文本支持方案
2018/11/28 Javascript
angular4+百分比进度显示插件用法示例
2019/05/05 Javascript
python抓取网站的图片并下载到本地的方法
2018/05/22 Python
Python3.x爬虫下载网页图片的实例讲解
2018/05/22 Python
Python 利用pydub库操作音频文件的方法
2019/01/09 Python
pytz格式化北京时间多出6分钟问题的解决方法
2019/06/21 Python
Python OpenCV之图片缩放的实现(cv2.resize)
2019/06/28 Python
如何使用python-opencv批量生成带噪点噪线的数字验证码
2020/12/21 Python
Maison Lab荷兰:名牌Outlet购物
2018/08/10 全球购物
你所知道的集合类都有哪些?主要方法?
2012/12/31 面试题
在校大学生个人的自我评价
2014/02/13 职场文书
事业单位工作人员年度考核个人总结
2015/02/12 职场文书
Canvas跟随鼠标炫彩小球的实现
2021/04/11 Javascript
Django实现在线无水印抖音视频下载(附源码及地址)
2021/05/06 Python
Java实现简单小画板
2022/06/10 Java/Android