详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教你学会使用Python正则表达式
Sep 07 Python
Python基础练习之用户登录实现代码分享
Nov 08 Python
Python 实现Windows开机运行某软件的方法
Oct 14 Python
Python Numpy 实现交换两行和两列的方法
Jun 26 Python
django中forms组件的使用与注意
Jul 08 Python
python实现的生成word文档功能示例
Aug 23 Python
python实现opencv+scoket网络实时图传
Mar 20 Python
在python中实现求输出1-3+5-7+9-......101的和
Apr 02 Python
Python OpenCV实现测量图片物体宽度
May 27 Python
python判断正负数方式
Jun 03 Python
如何在mac版pycharm选择python版本
Jul 21 Python
Pythonic版二分查找实现过程原理解析
Aug 11 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
Ajax实时验证用户名/邮箱等是否已经存在的代码打包
2011/12/01 PHP
PHP基础知识介绍
2013/09/17 PHP
浅析Yii2集成富文本编辑器redactor实例教程
2016/04/25 PHP
Yii2实现跨mysql数据库关联查询排序功能代码
2017/02/10 PHP
PHP里的$_GET数组介绍
2019/03/22 PHP
利用ASP发送和接收XML数据的处理方法与代码
2007/11/13 Javascript
修改好的jquery滚动字幕效果实现代码
2011/06/22 Javascript
javascript 二进制运算技巧解析
2012/11/27 Javascript
jquery简单实现鼠标经过导航条改变背景图
2013/12/17 Javascript
javascript版的in_array函数(判断数组中是否存在特定值)
2014/05/09 Javascript
Javascript实现获取及设置光标位置的方法
2015/07/21 Javascript
jquery层级选择器的实现(匹配后代元素div)
2016/09/05 Javascript
Node.js读写文件之批量替换图片的实现方法
2016/09/07 Javascript
vue使用v-if v-show页面闪烁,div闪现的解决方法
2018/10/12 Javascript
js实现移动端tab切换时下划线滑动效果
2019/09/08 Javascript
Vue实现简单计算器
2021/01/20 Vue.js
python设置windows桌面壁纸的实现代码
2013/01/28 Python
python通过zlib实现压缩与解压字符串的方法
2014/11/19 Python
Python查询阿里巴巴关键字排名的方法
2015/07/08 Python
详解JavaScript编程中的window与window.screen对象
2015/10/26 Python
python中正则的使用指南
2016/12/04 Python
Python信息抽取之乱码解决办法
2017/06/29 Python
对Tensorflow中的变量初始化函数详解
2018/07/27 Python
python 返回一个列表中第二大的数方法
2019/07/09 Python
python集合常见运算案例解析
2019/10/17 Python
解决安装新版PyQt5、PyQT5-tool后打不开并Designer.exe提示no Qt platform plugin的问题
2020/04/24 Python
django ObjectDoesNotExist 和 DoesNotExist的用法
2020/07/09 Python
python 模拟登录B站的示例代码
2020/12/15 Python
英国男士时尚网站:Dandy Fellow
2018/02/09 全球购物
写好求职应聘自荐信的三部曲
2013/09/21 职场文书
工地门卫岗位职责
2013/12/30 职场文书
《骆驼和羊》教学反思
2014/02/27 职场文书
村安全生产责任书
2014/08/25 职场文书
个人对照检查材料思想汇报
2014/09/26 职场文书
2015年公务员工作总结
2015/04/24 职场文书
「Manga Time Kirara MAX」2022年5月号封面公开
2022/03/21 日漫