手把手教你实现PyTorch的MNIST数据集


Posted in Python onJune 28, 2021

概述

MNIST 包含 0~9 的手写数字, 共有 60000 个训练集和 10000 个测试集. 数据的格式为单通道 28*28 的灰度图.

手把手教你实现PyTorch的MNIST数据集

获取数据

def get_data():
    """获取数据"""

    # 获取测试集
    train = torchvision.datasets.MNIST(root="./data", train=True, download=True,
                                       transform=torchvision.transforms.Compose([
                                           torchvision.transforms.ToTensor(),  # 转换成张量
                                           torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化
                                       ]))
    train_loader = DataLoader(train, batch_size=batch_size)  # 分割测试集

    # 获取测试集
    test = torchvision.datasets.MNIST(root="./data", train=False, download=True,
                                      transform=torchvision.transforms.Compose([
                                          torchvision.transforms.ToTensor(),  # 转换成张量
                                          torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化
                                      ]))
    test_loader = DataLoader(test, batch_size=batch_size)  # 分割训练

    # 返回分割好的训练集和测试集
    return train_loader, test_loader

网络模型

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        # 卷积层
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))

        # Dropout层
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)

        # 全连接层
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """前向传播"""
        
        # [b, 1, 28, 28] => [b, 32, 26, 26]
        out = self.conv1(x)
        out = F.relu(out)
        
        # [b, 32, 26, 26] => [b, 64, 24, 24]
        out = self.conv2(out)
        out = F.relu(out)

        # [b, 64, 24, 24] => [b, 64, 12, 12]
        out = F.max_pool2d(out, 2)
        out = self.dropout1(out)
        
        # [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]
        out = torch.flatten(out, 1)
        
        # [b, 9216] => [b, 128]
        out = self.fc1(out)
        out = F.relu(out)

        # [b, 128] => [b, 10]
        out = self.dropout2(out)
        out = self.fc2(out)

        output = F.log_softmax(out, dim=1)

        return output

train 函数

def train(model, epoch, train_loader):
    """训练"""

    # 训练模式
    model.train()

    # 迭代
    for step, (x, y) in enumerate(train_loader):
        # 加速
        if use_cuda:
            model = model.cuda()
            x, y = x.cuda(), y.cuda()

        # 梯度清零
        optimizer.zero_grad()

        output = model(x)

        # 计算损失
        loss = F.nll_loss(output, y)

        # 反向传播
        loss.backward()

        # 更新梯度
        optimizer.step()

        # 打印损失
        if step % 50 == 0:
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss))

test 函数

def test(model, test_loader):
    """测试"""
    
    # 测试模式
    model.eval()

    # 存放正确个数
    correct = 0

    with torch.no_grad():
        for x, y in test_loader:

            # 加速
            if use_cuda:
                model = model.cuda()
                x, y = x.cuda(), y.cuda()

            # 获取结果
            output = model(x)

            # 预测结果
            pred = output.argmax(dim=1, keepdim=True)

            # 计算准确个数
            correct += pred.eq(y.view_as(pred)).sum().item()

    # 计算准确率
    accuracy = correct / len(test_loader.dataset) * 100

    # 输出准确
    print("Test Accuracy: {}%".format(accuracy))

main 函数

def main():
    # 获取数据
    train_loader, test_loader = get_data()
    
    # 迭代
    for epoch in range(iteration_num):
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)

完整代码:

import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        # 卷积层
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))

        # Dropout层
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)

        # 全连接层
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """前向传播"""
        
        # [b, 1, 28, 28] => [b, 32, 26, 26]
        out = self.conv1(x)
        out = F.relu(out)
        
        # [b, 32, 26, 26] => [b, 64, 24, 24]
        out = self.conv2(out)
        out = F.relu(out)

        # [b, 64, 24, 24] => [b, 64, 12, 12]
        out = F.max_pool2d(out, 2)
        out = self.dropout1(out)
        
        # [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]
        out = torch.flatten(out, 1)
        
        # [b, 9216] => [b, 128]
        out = self.fc1(out)
        out = F.relu(out)

        # [b, 128] => [b, 10]
        out = self.dropout2(out)
        out = self.fc2(out)

        output = F.log_softmax(out, dim=1)

        return output


# 定义超参数
batch_size = 64  # 一次训练的样本数目
learning_rate = 0.0001  # 学习率
iteration_num = 5  # 迭代次数
network = Model()  # 实例化网络
print(network)  # 调试输出网络结构
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)  # 优化器

# GPU 加速
use_cuda = torch.cuda.is_available()
print("是否使用 GPU 加速:", use_cuda)


def get_data():
    """获取数据"""

    # 获取测试集
    train = torchvision.datasets.MNIST(root="./data", train=True, download=True,
                                       transform=torchvision.transforms.Compose([
                                           torchvision.transforms.ToTensor(),  # 转换成张量
                                           torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化
                                       ]))
    train_loader = DataLoader(train, batch_size=batch_size)  # 分割测试集

    # 获取测试集
    test = torchvision.datasets.MNIST(root="./data", train=False, download=True,
                                      transform=torchvision.transforms.Compose([
                                          torchvision.transforms.ToTensor(),  # 转换成张量
                                          torchvision.transforms.Normalize((0.1307,), (0.3081,))  # 标准化
                                      ]))
    test_loader = DataLoader(test, batch_size=batch_size)  # 分割训练

    # 返回分割好的训练集和测试集
    return train_loader, test_loader


def train(model, epoch, train_loader):
    """训练"""

    # 训练模式
    model.train()

    # 迭代
    for step, (x, y) in enumerate(train_loader):
        # 加速
        if use_cuda:
            model = model.cuda()
            x, y = x.cuda(), y.cuda()

        # 梯度清零
        optimizer.zero_grad()

        output = model(x)

        # 计算损失
        loss = F.nll_loss(output, y)

        # 反向传播
        loss.backward()

        # 更新梯度
        optimizer.step()

        # 打印损失
        if step % 50 == 0:
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss))


def test(model, test_loader):
    """测试"""

    # 测试模式
    model.eval()

    # 存放正确个数
    correct = 0

    with torch.no_grad():
        for x, y in test_loader:

            # 加速
            if use_cuda:
                model = model.cuda()
                x, y = x.cuda(), y.cuda()

            # 获取结果
            output = model(x)

            # 预测结果
            pred = output.argmax(dim=1, keepdim=True)

            # 计算准确个数
            correct += pred.eq(y.view_as(pred)).sum().item()

    # 计算准确率
    accuracy = correct / len(test_loader.dataset) * 100

    # 输出准确
    print("Test Accuracy: {}%".format(accuracy))


def main():
    # 获取数据
    train_loader, test_loader = get_data()

    # 迭代
    for epoch in range(iteration_num):
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)

if __name__ == "__main__":
    main()

输出结果:

Model(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
是否使用 GPU 加速: True

================ epoch: 0 ================
Epoch: 0, Step 0, Loss: 2.3131277561187744
Epoch: 0, Step 50, Loss: 1.0419045686721802
Epoch: 0, Step 100, Loss: 0.6259541511535645
Epoch: 0, Step 150, Loss: 0.7194482684135437
Epoch: 0, Step 200, Loss: 0.4020516574382782
Epoch: 0, Step 250, Loss: 0.6890509128570557
Epoch: 0, Step 300, Loss: 0.28660136461257935
Epoch: 0, Step 350, Loss: 0.3277580738067627
Epoch: 0, Step 400, Loss: 0.2750288248062134
Epoch: 0, Step 450, Loss: 0.28428223729133606
Epoch: 0, Step 500, Loss: 0.3514065444469452
Epoch: 0, Step 550, Loss: 0.23386947810649872
Epoch: 0, Step 600, Loss: 0.25338059663772583
Epoch: 0, Step 650, Loss: 0.1743898093700409
Epoch: 0, Step 700, Loss: 0.35752204060554504
Epoch: 0, Step 750, Loss: 0.17575909197330475
Epoch: 0, Step 800, Loss: 0.20604261755943298
Epoch: 0, Step 850, Loss: 0.17389622330665588
Epoch: 0, Step 900, Loss: 0.3188241124153137
Test Accuracy: 96.56%

================ epoch: 1 ================
Epoch: 1, Step 0, Loss: 0.23558208346366882
Epoch: 1, Step 50, Loss: 0.13511177897453308
Epoch: 1, Step 100, Loss: 0.18823786079883575
Epoch: 1, Step 150, Loss: 0.2644936144351959
Epoch: 1, Step 200, Loss: 0.145077645778656
Epoch: 1, Step 250, Loss: 0.30574971437454224
Epoch: 1, Step 300, Loss: 0.2386859953403473
Epoch: 1, Step 350, Loss: 0.08346735686063766
Epoch: 1, Step 400, Loss: 0.10480977594852448
Epoch: 1, Step 450, Loss: 0.07280707359313965
Epoch: 1, Step 500, Loss: 0.20928426086902618
Epoch: 1, Step 550, Loss: 0.20455852150917053
Epoch: 1, Step 600, Loss: 0.10085935145616531
Epoch: 1, Step 650, Loss: 0.13476189970970154
Epoch: 1, Step 700, Loss: 0.19087043404579163
Epoch: 1, Step 750, Loss: 0.0981522724032402
Epoch: 1, Step 800, Loss: 0.1961515098810196
Epoch: 1, Step 850, Loss: 0.041140712797641754
Epoch: 1, Step 900, Loss: 0.250461220741272
Test Accuracy: 98.03%

================ epoch: 2 ================
Epoch: 2, Step 0, Loss: 0.09572553634643555
Epoch: 2, Step 50, Loss: 0.10370486229658127
Epoch: 2, Step 100, Loss: 0.17737184464931488
Epoch: 2, Step 150, Loss: 0.1570713371038437
Epoch: 2, Step 200, Loss: 0.07462178170681
Epoch: 2, Step 250, Loss: 0.18744900822639465
Epoch: 2, Step 300, Loss: 0.09910508990287781
Epoch: 2, Step 350, Loss: 0.08929706364870071
Epoch: 2, Step 400, Loss: 0.07703761011362076
Epoch: 2, Step 450, Loss: 0.10133732110261917
Epoch: 2, Step 500, Loss: 0.1314031481742859
Epoch: 2, Step 550, Loss: 0.10394387692213058
Epoch: 2, Step 600, Loss: 0.11612939089536667
Epoch: 2, Step 650, Loss: 0.17494803667068481
Epoch: 2, Step 700, Loss: 0.11065669357776642
Epoch: 2, Step 750, Loss: 0.061209067702293396
Epoch: 2, Step 800, Loss: 0.14715790748596191
Epoch: 2, Step 850, Loss: 0.03930797800421715
Epoch: 2, Step 900, Loss: 0.18030673265457153
Test Accuracy: 98.46000000000001%

================ epoch: 3 ================
Epoch: 3, Step 0, Loss: 0.09266342222690582
Epoch: 3, Step 50, Loss: 0.0414913073182106
Epoch: 3, Step 100, Loss: 0.2152961939573288
Epoch: 3, Step 150, Loss: 0.12287424504756927
Epoch: 3, Step 200, Loss: 0.13468700647354126
Epoch: 3, Step 250, Loss: 0.11967387050390244
Epoch: 3, Step 300, Loss: 0.11301510035991669
Epoch: 3, Step 350, Loss: 0.037447575479745865
Epoch: 3, Step 400, Loss: 0.04699449613690376
Epoch: 3, Step 450, Loss: 0.05472381412982941
Epoch: 3, Step 500, Loss: 0.09839300811290741
Epoch: 3, Step 550, Loss: 0.07964356243610382
Epoch: 3, Step 600, Loss: 0.08182843774557114
Epoch: 3, Step 650, Loss: 0.05514759197831154
Epoch: 3, Step 700, Loss: 0.13785190880298615
Epoch: 3, Step 750, Loss: 0.062480345368385315
Epoch: 3, Step 800, Loss: 0.120387002825737
Epoch: 3, Step 850, Loss: 0.04458726942539215
Epoch: 3, Step 900, Loss: 0.17119190096855164
Test Accuracy: 98.55000000000001%

================ epoch: 4 ================
Epoch: 4, Step 0, Loss: 0.08094145357608795
Epoch: 4, Step 50, Loss: 0.05615215748548508
Epoch: 4, Step 100, Loss: 0.07766406238079071
Epoch: 4, Step 150, Loss: 0.07915271818637848
Epoch: 4, Step 200, Loss: 0.1301635503768921
Epoch: 4, Step 250, Loss: 0.12118984013795853
Epoch: 4, Step 300, Loss: 0.073218435049057
Epoch: 4, Step 350, Loss: 0.04517696052789688
Epoch: 4, Step 400, Loss: 0.08493026345968246
Epoch: 4, Step 450, Loss: 0.03904269263148308
Epoch: 4, Step 500, Loss: 0.09386837482452393
Epoch: 4, Step 550, Loss: 0.12583576142787933
Epoch: 4, Step 600, Loss: 0.09053893387317657
Epoch: 4, Step 650, Loss: 0.06912104040384293
Epoch: 4, Step 700, Loss: 0.1502612829208374
Epoch: 4, Step 750, Loss: 0.07162325084209442
Epoch: 4, Step 800, Loss: 0.10512275993824005
Epoch: 4, Step 850, Loss: 0.028180215507745743
Epoch: 4, Step 900, Loss: 0.08492615073919296
Test Accuracy: 98.69%

到此这篇关于手把手教你实现PyTorch的MNIST数据集的文章就介绍到这了,更多相关PyTorch MNIST数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
pyqt4教程之实现半透明的天气预报界面示例
Mar 02 Python
Python Sql数据库增删改查操作简单封装
Apr 18 Python
Python实现登录接口的示例代码
Jul 21 Python
Python文件和流(实例讲解)
Sep 12 Python
Python实现的网页截图功能【PyQt4与selenium组件】
Jul 12 Python
python+selenium实现自动抢票功能实例代码
Nov 23 Python
python实现最大子序和(分治+动态规划)
Jul 05 Python
python 三元运算符使用解析
Sep 16 Python
Django 请求Request的具体使用方法
Nov 11 Python
使用python切片实现二维数组复制示例
Nov 26 Python
python 将视频 通过视频帧转换成时间实例
Apr 23 Python
keras 多gpu并行运行案例
Jun 10 Python
PyMongo 查询数据的实现
Jun 28 #Python
浅谈哪个Python库才最适合做数据可视化
总结Python变量的相关知识
详解非极大值抑制算法之Python实现
Python实现生活常识解答机器人
Python办公自动化之教你如何用Python将任意文件转为PDF格式
Python移位密码、仿射变换解密实例代码
You might like
php桌面中心(四) 数据显示
2007/03/11 PHP
php解决DOM乱码的方法示例代码
2016/11/20 PHP
JavaScript 学习笔记(十二) dom
2010/01/21 Javascript
JQuery 动态扩展对象之另类视角
2010/05/25 Javascript
js 判断一个元素是否在页面中存在
2012/12/27 Javascript
jQuery的attr与prop使用介绍
2013/10/10 Javascript
写出高效jquery代码的19条指南
2014/03/19 Javascript
在JavaScript里防止事件函数高频触发和高频调用的方法
2014/09/06 Javascript
吐槽一下我所了解的Node.js
2014/10/08 Javascript
javascript笛卡尔积算法实现方法
2015/04/08 Javascript
js实现顶部可折叠的菜单工具栏效果实例
2015/05/09 Javascript
JS数组array元素的添加和删除方法代码实例
2015/06/01 Javascript
学习使用AngularJS文件上传控件
2016/02/16 Javascript
jquery设置css样式的多种方法(总结)
2017/02/21 Javascript
Angular 4.x 路由快速入门学习
2017/05/03 Javascript
使用Angular CLI生成 Angular 5项目教程详解
2018/03/18 Javascript
js canvas实现橡皮擦效果
2018/12/20 Javascript
关于layui 弹出层一闪而过就消失的解决方法
2019/09/09 Javascript
JS实现基本的网页计算器功能示例
2020/01/16 Javascript
vue 虚拟DOM的原理
2020/10/03 Javascript
解决vant的Toast组件时提示not defined的问题
2020/11/11 Javascript
Python同时向控制台和文件输出日志logging的方法
2015/05/26 Python
Python中摘要算法MD5,SHA1简介及应用实例代码
2018/01/09 Python
利用Opencv中Houghline方法实现直线检测
2018/02/11 Python
python解决js文件utf-8编码乱码问题(推荐)
2018/05/02 Python
Python爬虫框架scrapy实现的文件下载功能示例
2018/08/04 Python
python读取文件名并改名字的实例
2019/01/07 Python
详解python模块pychartdir安装及导入问题
2020/10/22 Python
纯css3实现的动画按钮的实例教程
2014/11/17 HTML / CSS
雅萌 (YA-MAN) :日本美容家电领域的龙头企业
2017/05/12 全球购物
机械专业毕业生推荐信范文
2013/11/25 职场文书
中层竞聘演讲稿
2014/01/09 职场文书
《一本男孩子必读的书》教学反思
2014/02/19 职场文书
请假条范文大全
2014/04/10 职场文书
售房协议书范本
2015/08/11 职场文书
《落花生》教学反思
2016/02/16 职场文书