详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python判断变量是否已经定义的方法
Aug 18 Python
python中__call__方法示例分析
Oct 11 Python
Python使用matplotlib绘制动画的方法
May 20 Python
浅谈利用numpy对矩阵进行归一化处理的方法
Jul 11 Python
Python3非对称加密算法RSA实例详解
Dec 06 Python
Python代码打开本地.mp4格式文件的方法
Jan 03 Python
Python之使用adb shell命令启动应用的方法详解
Jan 07 Python
python字符串查找函数的用法详解
Jul 08 Python
python实现邮件自动发送
Aug 10 Python
python爬虫爬取笔趣网小说网站过程图解
Nov 18 Python
python3环境搭建过程(利用Anaconda+pycharm)完整版
Aug 19 Python
用python写PDF转换器的实现
Oct 29 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
简单的用PHP编写的导航条程序
2006/10/09 PHP
php入门教程 精简版
2009/12/13 PHP
php实现可以设置中奖概率的抽奖程序代码分享
2014/01/19 PHP
PHP根据树的前序遍历和中序遍历构造树并输出后序遍历的方法
2017/11/10 PHP
PHP绕过open_basedir限制操作文件的方法
2018/06/10 PHP
javascript应用:Iframe自适应其加载的内容高度
2007/04/10 Javascript
javascript 显示当前系统时间代码
2009/12/28 Javascript
javascript 哈希表(hashtable)的简单实现
2010/01/20 Javascript
Three.js源码阅读笔记(Object3D类)
2012/12/27 Javascript
javascript实现控制浏览器全屏
2015/03/30 Javascript
简单讲解jQuery中的子元素过滤选择器
2016/04/18 Javascript
轻松掌握JavaScript策略模式
2016/08/25 Javascript
jQuery实现导航高亮的方法【附demo源码下载】
2016/11/09 Javascript
JS中Select下拉列表类(支持输入模糊查询)功能
2017/01/17 Javascript
JS拉起或下载app的实现代码
2017/02/22 Javascript
完美解决input[type=number]无法显示非数字字符的问题
2017/02/28 Javascript
利用pm2部署多个node.js项目的配置教程
2017/10/22 Javascript
angularjs中$http异步上传Excel文件方法
2018/02/23 Javascript
如何在postman测试用例中实现断言过程解析
2020/07/09 Javascript
解决vue加scoped后就无法修改vant的UI组件的样式问题
2020/09/07 Javascript
[48:32]VGJ.T vs Fnatic 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python获取android设备的GPS信息脚本分享
2015/03/06 Python
Python读写unicode文件的方法
2015/07/10 Python
python 3.0 模拟用户登录功能并实现三次错误锁定
2017/11/01 Python
python之pyqt5通过按钮改变Label的背景颜色方法
2019/06/13 Python
修改 CentOS 6.x 上默认Python的方法
2019/09/06 Python
.net软件工程师应聘上机试题
2015/03/10 面试题
工作岗位说明书模板
2014/05/09 职场文书
公司优秀员工获奖感言
2014/08/14 职场文书
大学社团招新的通讯稿
2014/09/10 职场文书
2015公务员试用期工作总结
2014/12/12 职场文书
仓管员岗位职责
2015/02/03 职场文书
酒店员工手册范本
2015/05/14 职场文书
教师个人教学反思
2016/02/23 职场文书
《语言的突破》读后感3篇
2019/12/12 职场文书
vue-router中hash模式与history模式的区别
2021/06/23 Vue.js