详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python抓取网页时字符集转换问题处理方案分享
Jun 19 Python
python中将函数赋值给变量时需要注意的一些问题
Aug 18 Python
pycharm下打开、执行并调试scrapy爬虫程序的方法
Nov 29 Python
Django REST为文件属性输出完整URL的方法
Dec 18 Python
详解python函数传参是传值还是传引用
Jan 16 Python
python生成器与迭代器详解
Jan 01 Python
python儿童学游戏编程知识点总结
Jun 03 Python
scrapy数据存储在mysql数据库的两种方式(同步和异步)
Feb 18 Python
Python基于yield遍历多个可迭代对象
Mar 12 Python
什么是Python中的匿名函数
Jun 02 Python
Python基于smtplib协议实现发送邮件
Jun 03 Python
Pytorch使用shuffle打乱数据的操作
May 20 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
java EJB 加密与解密原理的一个例子
2008/01/11 PHP
php后台程序与Javascript的两种交互方式
2009/10/25 PHP
基于Windows下Apache PHP5.3.1安装教程
2010/01/08 PHP
基于Snoopy的PHP近似完美获取网站编码的代码
2011/10/23 PHP
PHP使用imagick读取PDF生成png缩略图的两种方法
2014/03/20 PHP
php+ajax注册实时验证功能
2016/07/20 PHP
PHP编程获取各个时间段具体时间的方法
2017/05/26 PHP
PDO::lastInsertId讲解
2019/01/29 PHP
Laravel 5.1 框架Blade模板引擎用法实例分析
2020/01/04 PHP
基于Jquery的淡入淡出的特效基础练习
2010/12/13 Javascript
THREE.JS入门教程(4)创建粒子系统
2013/01/24 Javascript
Node.js开发教程之基于OnceIO框架实现文件上传和验证功能
2016/11/30 Javascript
超全面的JavaScript开发规范(推荐)
2017/01/21 Javascript
JavaScript原型继承_动力节点Java学院整理
2017/06/30 Javascript
BetterScroll 在移动端滚动场景的应用
2017/09/18 Javascript
Nodejs连接mysql并实现增、删、改、查操作的方法详解
2018/01/04 NodeJs
js实现同一个页面,多个enter事件绑定的示例
2018/10/10 Javascript
小程序实现左滑删除效果
2019/07/25 Javascript
如何利用JavaScript编写一个格斗小游戏
2021/01/06 Javascript
[02:47]3.19DOTA2发布会 国服成长历程回顾
2014/03/25 DOTA
Python字符串格式化
2015/06/15 Python
Python3 读、写Excel文件的操作方法
2018/10/20 Python
Python实现自定义读写分离代码实例
2019/11/16 Python
利用PyCharm操作Github(仓库新建、更新,代码回滚)
2019/12/18 Python
Pytorch实现LSTM和GRU示例
2020/01/14 Python
MxNet预训练模型到Pytorch模型的转换方式
2020/05/25 Python
一个不错的HTML5 Canvas多层点击事件监听实例
2014/04/29 HTML / CSS
北京麒麟网信息技术有限公司网络游戏测试面试题
2013/09/28 面试题
环境工程大学生自荐信
2013/10/21 职场文书
药店促销活动总结
2014/07/10 职场文书
2015年企业员工工作总结范文
2015/05/21 职场文书
党小组意见范文
2015/06/08 职场文书
餐馆开业致辞
2015/08/01 职场文书
2019学校运动会开幕词
2019/05/13 职场文书
聊聊golang中多个defer的执行顺序
2021/05/08 Golang
Pandas 数据编码的十种方法
2022/04/20 Python