详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
PyQt5利用QPainter绘制各种图形的实例
Oct 19 Python
numpy中的delete删除数组整行和整列的实例
May 09 Python
python删除本地夹里重复文件的方法
Nov 19 Python
python使用xlsxwriter实现有向无环图到Excel的转换
Dec 12 Python
Python判断两个文件是否相同与两个文本进行相同项筛选的方法
Mar 01 Python
wxPython实现整点报时
Nov 18 Python
基于Python执行dos命令并获取输出的结果
Dec 30 Python
pyCharm 实现关闭代码检查
Jun 09 Python
Python通过yagmail实现发送邮件代码解析
Oct 27 Python
怎么解决pycharm license Acti的方法
Oct 28 Python
Python获取指定网段正在使用的IP
Dec 14 Python
python 实现的截屏工具
May 08 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
PHP制作图型计数器的例子
2006/10/09 PHP
php短址转换实现方法
2015/02/25 PHP
php基础教程
2015/08/26 PHP
Avengerls vs Newbee BO3 第二场2.18
2021/03/10 DOTA
JavaScript库 开发规则
2009/01/31 Javascript
JavaScript是否可实现多线程  深入理解JavaScript定时机制
2009/12/22 Javascript
jQuery autocomplate 自扩展插件、自动完成示例代码
2011/03/28 Javascript
javascript里模拟sleep(两种实现方式)
2013/01/25 Javascript
浅谈关于JavaScript的语言特性分析
2013/04/11 Javascript
Nodejs sublime text 3安装与配置
2014/06/19 NodeJs
javascript实现拖动元素交换位置
2015/11/29 Javascript
javascript实现禁止复制网页内容汇总
2015/12/30 Javascript
利用JS生成博文目录及CSS定制博客
2016/02/10 Javascript
老生常谈 js中this的指向
2016/06/30 Javascript
vue视频播放插件vue-video-player的具体使用方法
2019/11/08 Javascript
python通过pil为png图片填充上背景颜色的方法
2015/03/17 Python
Python脚本实现格式化css文件
2015/04/08 Python
Python中的推导式使用详解
2015/06/03 Python
Django中的CACHE_BACKEND参数和站点级Cache设置
2015/07/23 Python
VSCode下配置python调试运行环境的方法
2018/04/06 Python
pandas求两个表格不相交的集合方法
2018/12/08 Python
利用PyCharm Profile分析异步爬虫效率详解
2019/05/08 Python
详解Django配置优化方法
2019/11/18 Python
将 Ubuntu 16 和 18 上的 python 升级到最新 python3.8 的方法教程
2020/03/11 Python
Python打印不合法的文件名
2020/07/31 Python
浅谈Python描述数据结构之KMP篇
2020/09/06 Python
DRF使用simple JWT身份验证的实现
2021/01/14 Python
用html5实现语音搜索框的方法
2014/03/18 HTML / CSS
Nordgreen美国官网:在线购买极简主义斯堪的纳维亚手表
2019/07/24 全球购物
SOKOLOV官网:俄罗斯珠宝首饰品牌
2021/01/02 全球购物
李培根演讲稿
2014/05/22 职场文书
四风问题自我剖析材料
2014/10/07 职场文书
依法行政工作汇报材料
2014/10/28 职场文书
交通事故赔偿起诉书
2015/05/20 职场文书
承兑汇票延期证明
2015/06/23 职场文书
python实现简单倒计时功能
2021/04/21 Python