详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python读写excel的方法
Nov 18 Python
python对url格式解析的方法
May 13 Python
python实现将英文单词表示的数字转换成阿拉伯数字的方法
Jul 02 Python
python 使用get_argument获取url query参数
Apr 28 Python
简单谈谈Python中的json与pickle
Jul 19 Python
Flask之请求钩子的实现
Dec 23 Python
Python requests模块实例用法
Feb 11 Python
python 处理微信对账单数据的实例代码
Jul 19 Python
pygame实现成语填空游戏
Oct 29 Python
python urllib和urllib3知识点总结
Feb 08 Python
python实现简单文件读写函数
Feb 25 Python
python实现控制台输出颜色
Mar 02 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
中国广播史趣谈 — 几个历史第一次
2021/03/01 无线电
php 计算两个时间戳相隔的时间的函数(小时)
2009/12/18 PHP
PHP+Mysql+jQuery查询和列表框选择操作实例讲解
2015/10/22 PHP
yii2超好用的日期组件和时间组件
2016/05/05 PHP
PHP二维关联数组的遍历方式(实例讲解)
2017/10/18 PHP
详谈PHP中public,private,protected,abstract等关键字的用法
2017/12/31 PHP
6个常见的 PHP 安全性攻击实例和阻止方法
2020/12/16 PHP
javascript实现仿银行密码输入框效果的代码
2007/12/13 Javascript
js实现GridView单选效果自动设置交替行、选中行、鼠标移动行背景色
2010/05/27 Javascript
jQuery前台数据获取实现代码
2011/03/16 Javascript
判断多个input type=file是否有已经选择好文件的代码
2012/05/23 Javascript
JQuery实现用户名无刷新验证的小例子
2013/03/22 Javascript
JS组件Bootstrap Table布局详解
2016/05/27 Javascript
JS/jQuery判断DOM节点是否存在的简单方法
2016/11/24 Javascript
微信小程序 首页制作简单实例
2017/04/07 Javascript
基于nodejs实现微信支付功能
2017/12/20 NodeJs
基于 Immutable.js 实现撤销重做功能的实例代码
2018/03/01 Javascript
angularjs1.X 重构controller 的方法小结
2019/08/15 Javascript
解决vue-cli 打包后自定义动画未执行的问题
2019/11/12 Javascript
基于JavaScript实现简单的轮播图
2021/03/03 Javascript
Python中使用中文的方法
2011/02/19 Python
Python判断字符串是否为字母或者数字(浮点数)的多种方法
2018/08/03 Python
用Python将mysql数据导出成json的方法
2018/08/21 Python
关于Python作用域自学总结
2019/06/10 Python
对pyqt5之menu和action的使用详解
2019/06/20 Python
python 修改本地网络配置的方法
2019/08/14 Python
CSS实现圆形放大镜狙击镜效果 只有圆圈里的放大
2012/12/10 HTML / CSS
英国女士家居服网站:hush
2017/08/09 全球购物
巴西网上药店:Drogaria Araujo
2021/01/06 全球购物
小区消防演习方案
2014/02/21 职场文书
夜不归宿检讨书
2014/02/25 职场文书
大学生个人自荐信样本
2014/03/02 职场文书
《春天来了》教学反思
2014/04/07 职场文书
工作经常出错的检讨书
2014/09/13 职场文书
企业党员岗位承诺书
2015/04/27 职场文书
2019大学毕业晚会主持词
2019/06/21 职场文书