详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python脚本实现查找webshell的方法
Jul 31 Python
将字典转换为DataFrame并进行频次统计的方法
Apr 08 Python
python利用requests库模拟post请求时json的使用教程
Dec 07 Python
python对列进行平移变换的方法(shift)
Jan 10 Python
Numpy之random函数使用学习
Jan 29 Python
Python使用random模块生成随机数操作实例详解
Sep 17 Python
Python 静态方法和类方法实例分析
Nov 21 Python
python 多维高斯分布数据生成方式
Dec 09 Python
JAVA SWT事件四种写法实例解析
Jun 05 Python
keras实现theano和tensorflow训练的模型相互转换
Jun 19 Python
Python 如何反方向迭代一个序列
Jul 28 Python
Python中读取文件名中的数字的实例详解
Dec 25 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
PHP中利用substr_replace将指定两位置之间的字符替换为*号
2011/01/27 PHP
写出高质量的PHP程序
2012/02/04 PHP
PHP集成百度Ueditor 1.4.3
2014/11/23 PHP
PHP获取数组长度或某个值出现次数的方法
2015/02/11 PHP
php实现插入排序
2015/03/29 PHP
PHP正则+Snoopy抓取框架实现的抓取淘宝店信誉功能实例
2017/05/17 PHP
CSS中简写属性要注意TRouBLe的顺序问题(避免踩坑)
2021/03/09 HTML / CSS
arguments对象
2006/11/20 Javascript
页面使用密码保护代码
2013/04/10 Javascript
jQuery让控件左右移动的三种实现方法
2013/09/08 Javascript
php显示当前文件所在的文件以及文件夹所有文件以树形展开
2013/12/13 Javascript
jQuery aminate方法定位到页面具体位置
2013/12/26 Javascript
关闭页面window.location事件未执行的原因及解决方法
2014/09/01 Javascript
jQuery中toggle()函数的使用实例
2015/04/17 Javascript
javascript实现点击按钮弹出一个可关闭层窗口同时网页背景变灰的方法
2015/05/13 Javascript
javascript的正则匹配方法学习
2016/02/24 Javascript
jQuery Ajax请求后台数据并在前台接收
2016/12/10 Javascript
jQuery实现上传图片前预览效果功能
2017/08/03 jQuery
jQuery事件对象的属性和方法详解
2017/09/09 jQuery
js移动端图片压缩上传功能
2020/08/18 Javascript
详解Webpack实战之构建 Electron 应用
2017/12/25 Javascript
nodejs更新package.json中的dependencies依赖到最新版本的方法
2018/10/10 NodeJs
jQuery pagination分页示例详解
2018/10/23 jQuery
django简单的前后端分离的数据传输实例 axios
2020/05/18 Javascript
Python 获取新浪微博的最新公共微博实例分享
2014/07/03 Python
简单谈谈Python中的几种常见的数据类型
2017/02/10 Python
Python 中的Selenium异常处理实例代码
2018/05/03 Python
详解python的sorted函数对字典按key排序和按value排序
2018/08/10 Python
python/sympy求解矩阵方程的方法
2018/11/08 Python
Jupyter 无法下载文件夹如何实现曲线救国
2020/04/22 Python
联想C++笔试题
2012/06/13 面试题
职业教育毕业生求职信
2013/11/09 职场文书
清洁工岗位职责
2014/01/29 职场文书
《广玉兰》教学反思
2014/04/14 职场文书
竞选班长演讲稿500字
2014/08/22 职场文书
2019年恭贺升学祝福语集锦
2019/08/15 职场文书