详解PyTorch手写数字识别(MNIST数据集)


Posted in Python onAugust 16, 2019

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集

# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络

# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #将梯度归零
      outputs = net(inputs) #将数据传入网络进行前向运算
      loss = criterion(outputs, labels) #得到损失函数
      loss.backward() #反向传播
      optimizer.step() #通过梯度做一步参数更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

测试模型

net.eval() #将模型变换为测试模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用C++封装MySQL的API的教程
May 06 Python
Python开发SQLite3数据库相关操作详解【连接,查询,插入,更新,删除,关闭等】
Jul 27 Python
Python+Turtle动态绘制一棵树实例分享
Jan 16 Python
Python中GeoJson和bokeh-1的使用讲解
Jan 03 Python
如何在Python中实现goto语句的方法
May 18 Python
Python如何调用外部系统命令
Aug 07 Python
python Plotly绘图工具的简单使用
Mar 03 Python
Python定义函数实现累计求和操作
May 03 Python
浅谈pycharm导入pandas包遇到的问题及解决
Jun 01 Python
python下对hsv颜色空间进行量化操作
Jun 04 Python
使用pytorch实现线性回归
Apr 11 Python
端午节将至,用Python爬取粽子数据并可视化,看看网友喜欢哪种粽子吧!
Jun 11 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 #Python
Python 分发包中添加额外文件的方法
Aug 16 #Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 #Python
基于django传递数据到后端的例子
Aug 16 #Python
Django 拆分model和view的实现方法
Aug 16 #Python
利用Python实现kNN算法的代码
Aug 16 #Python
python实现kNN算法识别手写体数字的示例代码
Aug 16 #Python
You might like
PHP面向接口编程 耦合设计模式 简单范例
2011/03/23 PHP
PHP 基于Yii框架中使用smarty模板的方法详解
2013/06/13 PHP
ThinkPHP的I方法使用详解
2014/06/18 PHP
ThinkPHP中I(),U(),$this->post()等函数用法
2014/11/22 PHP
PHP_SELF,SCRIPT_NAME,REQUEST_URI区别
2014/12/24 PHP
ThinkPHP实现递归无级分类――代码少
2015/07/29 PHP
关于图片验证码设计的思考
2007/01/29 Javascript
JS 跳转页面延迟2种方法
2013/03/29 Javascript
JavaScript SetInterval与setTimeout使用方法详解
2013/11/15 Javascript
jquery ztree实现下拉树形框使用到了json数据
2014/05/14 Javascript
js判断当页面无法回退时关闭网页否则就history.go(-1)
2014/08/07 Javascript
JavaScript合并两个数组并去除重复项的方法
2015/06/13 Javascript
Jquery中attr与prop的区别详解
2017/05/27 jQuery
JSON创建键值对(key是中文或者数字)方式详解
2017/08/24 Javascript
BetterScroll 在移动端滚动场景的应用
2017/09/18 Javascript
JS实现定时任务每隔N秒请求后台setInterval定时和ajax请求问题
2017/10/15 Javascript
JS实现十字坐标跟随鼠标效果
2017/12/25 Javascript
Vue.js项目中管理每个页面的头部标签的两种方法
2018/06/25 Javascript
小程序scroll-view安卓机隐藏横向滚动条的实现详解
2019/05/16 Javascript
Angular6项目打包优化的实现方法
2019/12/15 Javascript
详解React 元素渲染
2020/07/07 Javascript
[46:09]2014 DOTA2华西杯精英邀请赛 5 25 LGD VS VG第三场
2014/05/26 DOTA
Python3.x和Python2.x的区别介绍
2013/02/12 Python
python在windows下实现备份程序实例
2014/07/04 Python
对Python3 pyc 文件的使用详解
2019/02/16 Python
python分割一个文本为多个文本的方法
2019/07/22 Python
Python定时任务APScheduler原理及实例解析
2020/05/30 Python
numpy中生成随机数的几种常用函数(小结)
2020/08/18 Python
vue.js刷新当前页面的实例讲解
2020/12/29 Python
使用jquery实现HTML5响应式导航菜单教程
2014/04/02 HTML / CSS
数据库测试通常都包括哪些方面
2015/11/30 面试题
财务管理职业生涯规划范文
2013/12/27 职场文书
副科竞争上岗演讲稿
2014/05/12 职场文书
法律专业求职信
2014/05/24 职场文书
python scipy 稀疏矩阵的使用说明
2021/05/26 Python
Node.js实现断点续传
2021/06/23 Javascript