详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中with语句的用法
Apr 15 Python
Python通过正则表达式选取callback的方法
Jul 18 Python
pandas修改DataFrame列名的方法
Apr 08 Python
python3.4.3下逐行读入txt文本并去重的方法
Apr 29 Python
利用Python在一个文件的头部插入数据的实例
May 02 Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 Python
Python操作word常见方法示例【win32com与docx模块】
Jul 17 Python
利用python和百度地图API实现数据地图标注的方法
May 13 Python
Python获取、格式化当前时间日期的方法
Feb 10 Python
解决django中form表单设置action后无法回到原页面的问题
Mar 13 Python
scrapy在python爬虫中搭建出错的解决方法
Nov 22 Python
Python可视化神器pyecharts绘制地理图表
Jul 07 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
PHP生成静态页面详解
2006/11/19 PHP
php 执行系统命令的方法
2009/07/07 PHP
php反射应用示例
2014/02/25 PHP
php命名空间学习详解
2014/02/27 PHP
使用Thinkphp框架开发移动端接口
2015/08/05 PHP
PHP简单获取网站百度搜索和搜狗搜索收录量的方法
2016/08/23 PHP
javascript encodeURI和encodeURIComponent的比较
2010/04/03 Javascript
javascript学习笔记(五)正则表达式
2011/04/08 Javascript
说明你的Javascript技术很烂的五个原因
2011/04/26 Javascript
input:checkbox多选框实现单选效果跟radio一样
2014/06/16 Javascript
javascript学习笔记(三)BOM和DOM详解
2014/09/30 Javascript
jQuery Easyui datagrid行内实现【添加】、【编辑】、【上移】、【下移】
2016/12/19 Javascript
详解react使用react-bootstrap当轮子造车
2017/08/15 Javascript
使用vue-cli脚手架工具搭建vue-webpack项目
2019/01/14 Javascript
微信小程序实现简易table表格
2020/06/19 Javascript
js实现一个页面多个倒计时的3种方法
2019/02/25 Javascript
webpack4 SplitChunks实现代码分隔详解
2019/05/23 Javascript
python读取一个目录下所有txt里面的内容方法
2018/06/23 Python
python虚拟环境迁移方法
2019/01/03 Python
在PyCharm中实现添加快捷模块
2020/02/12 Python
Django中文件上传和文件访问微项目的方法
2020/04/27 Python
Timex手表官网:美国运动休闲手表品牌
2017/01/28 全球购物
在家更换处方镜片:Lensabl
2019/05/01 全球购物
Nike墨西哥官网:Nike MX
2020/08/30 全球购物
如何用Java判断一个文件或目录是否存在
2012/11/19 面试题
力学专业毕业生自荐信
2013/11/17 职场文书
新书发布会策划方案
2014/06/09 职场文书
节约用电标语
2014/06/17 职场文书
标准大学生职业生涯规划书写作指南
2014/09/18 职场文书
“四风”问题的主要表现和危害思想汇报
2014/09/19 职场文书
玄武湖导游词
2015/02/05 职场文书
欠条样本
2015/07/03 职场文书
2016年离婚协议书范文
2016/03/18 职场文书
迎客户欢迎词三篇
2019/09/27 职场文书
基于Redis位图实现用户签到功能
2021/05/08 Redis
Python中的协程(Coroutine)操作模块(greenlet、gevent)
2022/05/30 Python