详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
12步教你理解Python装饰器
Feb 25 Python
使用Python搭建虚拟环境的配置方法
Feb 28 Python
pytorch训练imagenet分类的方法
Jul 27 Python
详解Python 切片语法
Jun 10 Python
Python 中@property的用法详解
Jan 15 Python
PyTorch在Windows环境搭建的方法步骤
May 12 Python
Python爬虫requests库多种用法实例
May 28 Python
如何教少儿学习Python编程
Jul 10 Python
Python使用itcaht库实现微信自动收发消息功能
Jul 13 Python
改变 Python 中线程执行顺序的方法
Sep 24 Python
Python之基础函数案例详解
Aug 30 Python
Python 的 sum() Pythonic 的求和方法详细
Oct 16 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
PHP文件上传实例详解!!!
2007/01/02 PHP
php简单计算年龄的方法(周岁与虚岁)
2016/12/06 PHP
PHP封装的完整分页类示例
2018/08/21 PHP
Javascript 学习书 推荐
2009/06/13 Javascript
原生js实现查找/添加/删除/指定元素的class
2013/04/12 Javascript
关于jquery的多个选择器的使用示例
2013/10/18 Javascript
JavaScript中判断对象类型的几种方法总结
2013/11/11 Javascript
js获取表格的行数和列数的方法
2015/10/23 Javascript
javascript图片预加载完整实例
2015/12/10 Javascript
JS中的==运算: [''] == false —>true
2016/07/24 Javascript
新手vue构建单页面应用实例代码
2017/09/18 Javascript
使用html+js+css 实现页面轮播图效果(实例讲解)
2017/09/21 Javascript
JS路由跳转的简单实现代码
2017/09/21 Javascript
深入浅析Node环境和浏览器的区别
2018/08/14 Javascript
vuedraggable+element ui实现页面控件拖拽排序效果
2020/07/29 Javascript
利用Vue构造器创建Form组件的通用解决方法
2018/12/03 Javascript
微信小程序学习笔记之函数定义、页面渲染图文详解
2019/03/28 Javascript
Vue render函数实战之实现tabs选项卡组件
2019/04/22 Javascript
详解如何使用React Hooks请求数据并渲染
2020/10/18 Javascript
[03:11]2014DOTA2国际邀请赛-VG掉入败者组 独家专访357
2014/07/19 DOTA
浅谈Python中用datetime包进行对时间的一些操作
2016/06/23 Python
Python中.py文件打包成exe可执行文件详解
2017/03/22 Python
python 中pyqt5 树节点点击实现多窗口切换问题
2019/07/04 Python
学习python需要有编程基础吗
2020/06/02 Python
详解CSS3中@media的实际使用
2015/08/04 HTML / CSS
C/C++有关内存的思考题
2015/12/04 面试题
不同浏览器创建XMLHttpRequest方法有什么不同
2014/11/17 面试题
计算机专业个人求职信范例
2013/09/23 职场文书
学校运动会开幕演讲稿
2014/01/04 职场文书
元旦晚会上单位领导演讲稿
2014/01/05 职场文书
信息学院毕业生自荐信范文
2014/03/04 职场文书
机电专业求职信
2014/06/14 职场文书
合唱兴趣小组活动总结
2014/07/10 职场文书
在校大学生自我评价范文
2014/09/12 职场文书
老兵退伍感言
2015/08/03 职场文书
2019年关于小学生课外阅读情况的分析报告
2019/12/02 职场文书