详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入浅析python定时杀进程
Jun 06 Python
更改Ubuntu默认python版本的两种方法python-> Anaconda
Dec 18 Python
Python实现曲线点抽稀算法的示例
Oct 12 Python
Python实现PS滤镜Fish lens图像扭曲效果示例
Jan 29 Python
python sys,os,time模块的使用(包括时间格式的各种转换)
Apr 27 Python
解决pycharm运行出错,代码正确结果不显示的问题
Nov 30 Python
Python Django 命名空间模式的实现
Aug 09 Python
Python输出指定字符串的方法
Feb 06 Python
解决Python import docx出错DLL load failed的问题
Feb 13 Python
将pytorch转成longtensor的简单方法
Feb 18 Python
keras 回调函数Callbacks 断点ModelCheckpoint教程
Jun 18 Python
PyTorch device与cuda.device用法
Apr 03 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
thinkPHP的Html模板标签使用方法
2012/11/13 PHP
9个经典的PHP代码片段分享
2014/12/18 PHP
php通过正则表达式记取数据来读取xml的方法
2015/03/09 PHP
ThinkPHP3.2.3框架Memcache缓存使用方法实例总结
2019/04/15 PHP
php生成随机数/生成随机字符串的方法小结【5种方法】
2020/05/27 PHP
javascript 动态调整图片尺寸实现代码
2009/12/28 Javascript
JS链式调用的实现方法
2013/03/07 Javascript
枚举的实现求得1-1000所有出现1的数字并计算出现1的个数
2013/09/10 Javascript
jQuery标签替换函数replaceWith()的使用例子
2014/08/28 Javascript
JavaScript的各种常见函数定义方法
2014/09/16 Javascript
20条学习javascript的编程规范的建议
2014/11/28 Javascript
JavaScript获取Url里的参数
2014/12/18 Javascript
js实现分割上传大文件
2016/03/09 Javascript
详解jQuery中的empty、remove和detach
2016/04/11 Javascript
JS使用单链表统计英语单词出现次数
2016/06/16 Javascript
picLazyLoad 实现图片延时加载(包含背景图片)
2016/07/21 Javascript
JS两种类型的表单提交方法实例分析
2016/11/28 Javascript
JS ES6多行字符串与连接字符串的表示方法
2017/04/26 Javascript
w3c编程挑战_初级脚本算法实战篇
2017/06/23 Javascript
浅析vue-router jquery和params传参(接收参数)$router $route的区别
2018/08/03 jQuery
[02:44]DOTA2英雄基础教程 克林克兹
2014/01/15 DOTA
Python3之文件读写操作的实例讲解
2018/01/23 Python
使用python批量读取word文档并整理关键信息到excel表格的实例
2018/11/07 Python
sklearn-SVC实现与类参数详解
2019/12/10 Python
python如何调用字典的key
2020/05/25 Python
英国No.1文具和办公用品在线:Euroffice
2016/09/21 全球购物
俄罗斯珠宝市场的领导者之一:Бронницкий ювелир
2019/10/02 全球购物
Linux Interview Questions For software testers
2013/05/17 面试题
银行领导证婚词
2014/01/11 职场文书
金融管理专业毕业生求职信
2014/03/12 职场文书
应届大学生求职信
2014/07/20 职场文书
2014年医院工作总结
2014/11/20 职场文书
2015年村级财务管理制度
2015/08/04 职场文书
工伤事故赔偿协议书
2015/08/06 职场文书
2016年庆“七一”主题党日活动总结
2016/04/05 职场文书
python playwright之元素定位示例详解
2022/07/23 Python