手把手教你实现PyTorch的MNIST数据集


Posted in Python on June 28, 2021

概述

MNIST 数据集包含 0~9 的手写数字图片, 共有 60000 个训练样本和 10000 个测试样本. 每个样本都是单通道 28*28 的灰度图.

手把手教你实现PyTorch的MNIST数据集

获取数据

def get_data():
    """Build the MNIST training and test data loaders.

    The dataset is stored under ``./data`` and requested with
    ``download=True``. Both splits use the same preprocessing, and both
    loaders use the module-level ``batch_size``.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    # Shared preprocessing: convert each image to a tensor, then normalize
    # with mean 0.1307 and standard deviation 0.3081.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training split. Shuffle so that every epoch sees the samples in a
    # different order instead of replaying identical batches.
    train = torchvision.datasets.MNIST(root="./data", train=True, download=True,
                                       transform=transform)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)

    # Test split. Order does not matter for evaluation, so no shuffling.
    test = torchvision.datasets.MNIST(root="./data", train=False, download=True,
                                      transform=transform)
    test_loader = DataLoader(test, batch_size=batch_size)

    return train_loader, test_loader

网络模型

class Model(torch.nn.Module):
    """Convolutional classifier: two conv layers, one pooling step, two linear layers.

    The first linear layer expects 9216 (= 64 * 12 * 12) features, which is
    what a single-channel 28x28 input produces after the convolutions and
    the 2x2 max pooling. The forward pass returns log-probabilities over
    10 classes.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Convolutional feature extractor (3x3 kernels, stride 1, no padding).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))

        # Dropout applied after pooling and before the final linear layer.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Fully connected classifier head.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Compute per-class log-probabilities for a batch of images.

        Shape comments below assume an input of shape [b, 1, 28, 28].
        """
        # [b, 1, 28, 28] -> [b, 32, 26, 26]
        features = F.relu(self.conv1(x))

        # [b, 32, 26, 26] -> [b, 64, 24, 24]
        features = F.relu(self.conv2(features))

        # [b, 64, 24, 24] -> [b, 64, 12, 12]
        features = self.dropout1(F.max_pool2d(features, 2))

        # [b, 64, 12, 12] -> [b, 9216]
        flat = torch.flatten(features, 1)

        # [b, 9216] -> [b, 128]
        hidden = F.relu(self.fc1(flat))

        # [b, 128] -> [b, 10]
        logits = self.fc2(self.dropout2(hidden))

        return F.log_softmax(logits, dim=1)

train 函数

def train(model, epoch, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Relies on the module-level ``optimizer`` (which must hold this model's
    parameters) and the module-level ``use_cuda`` flag.

    Args:
        model: Network to train. Its output is fed to ``F.nll_loss``, so it
            is expected to return log-probabilities.
        epoch: Epoch index; only used in the progress message.
        train_loader: Iterable yielding ``(x, y)`` batches.
    """
    # Enable training behaviour (e.g. dropout).
    model.train()

    # Moving the model to the GPU is loop-invariant: do it once here
    # instead of on every batch.
    if use_cuda:
        model = model.cuda()

    for step, (x, y) in enumerate(train_loader):
        # Inputs must live on the same device as the model.
        if use_cuda:
            x, y = x.cuda(), y.cuda()

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        output = model(x)

        # Negative log-likelihood loss on the model output.
        loss = F.nll_loss(output, y)

        # Backpropagate and update the parameters.
        loss.backward()
        optimizer.step()

        # Report progress every 50 steps; .item() logs a plain Python float.
        if step % 50 == 0:
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss.item()))

test 函数

def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the accuracy in percent.

    Relies on the module-level ``use_cuda`` flag.

    Args:
        model: Network to evaluate; the predicted class is the argmax of
            its output along dimension 1.
        test_loader: Iterable yielding ``(x, y)`` batches and exposing a
            ``dataset`` attribute whose length is the total sample count.
    """
    # Enable evaluation behaviour (e.g. dropout disabled).
    model.eval()

    # Moving the model to the GPU is loop-invariant: do it once here
    # instead of on every batch.
    if use_cuda:
        model = model.cuda()

    # Running count of correctly classified samples.
    correct = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for x, y in test_loader:
            # Inputs must live on the same device as the model.
            if use_cuda:
                x, y = x.cuda(), y.cuda()

            output = model(x)

            # Predicted class = index of the largest output per sample.
            pred = output.argmax(dim=1, keepdim=True)

            # Count predictions that match the labels.
            correct += pred.eq(y.view_as(pred)).sum().item()

    # Accuracy as a percentage of the whole dataset.
    accuracy = correct / len(test_loader.dataset) * 100

    print("Test Accuracy: {}%".format(accuracy))

main 函数

def main():
    """Load the data, then train and evaluate the network once per epoch.

    Uses the module-level ``network`` and ``iteration_num``.
    """
    banner = "\n================ epoch: {} ================"
    training_batches, evaluation_batches = get_data()

    for epoch in range(iteration_num):
        print(banner.format(epoch))
        # One pass over the training data, followed by a test-set evaluation.
        train(network, epoch, training_batches)
        test(network, evaluation_batches)

完整代码:

import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader
class Model(torch.nn.Module):
    """Convolutional classifier: two conv layers, one pooling step, two linear layers.

    The first linear layer expects 9216 (= 64 * 12 * 12) features, which is
    what a single-channel 28x28 input produces after the convolutions and
    the 2x2 max pooling. The forward pass returns log-probabilities over
    10 classes.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Convolutional feature extractor (3x3 kernels, stride 1, no padding).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))

        # Dropout applied after pooling and before the final linear layer.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Fully connected classifier head.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Compute per-class log-probabilities for a batch of images.

        Shape comments below assume an input of shape [b, 1, 28, 28].
        """
        # [b, 1, 28, 28] -> [b, 32, 26, 26]
        features = F.relu(self.conv1(x))

        # [b, 32, 26, 26] -> [b, 64, 24, 24]
        features = F.relu(self.conv2(features))

        # [b, 64, 24, 24] -> [b, 64, 12, 12]
        features = self.dropout1(F.max_pool2d(features, 2))

        # [b, 64, 12, 12] -> [b, 9216]
        flat = torch.flatten(features, 1)

        # [b, 9216] -> [b, 128]
        hidden = F.relu(self.fc1(flat))

        # [b, 128] -> [b, 10]
        logits = self.fc2(self.dropout2(hidden))

        return F.log_softmax(logits, dim=1)


# Hyperparameters and module-level training state shared by the functions below
batch_size = 64  # number of samples per batch
learning_rate = 0.0001  # learning rate passed to the optimizer
iteration_num = 5  # number of epochs run by main()
network = Model()  # instantiate the network
print(network)  # debug output: print the network structure
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)  # optimizer over the network's parameters

# GPU acceleration: True when a CUDA device is available
use_cuda = torch.cuda.is_available()
print("是否使用 GPU 加速:", use_cuda)


def get_data():
    """Build the MNIST training and test data loaders.

    The dataset is stored under ``./data`` and requested with
    ``download=True``. Both splits use the same preprocessing, and both
    loaders use the module-level ``batch_size``.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    # Shared preprocessing: convert each image to a tensor, then normalize
    # with mean 0.1307 and standard deviation 0.3081.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training split. Shuffle so that every epoch sees the samples in a
    # different order instead of replaying identical batches.
    train = torchvision.datasets.MNIST(root="./data", train=True, download=True,
                                       transform=transform)
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)

    # Test split. Order does not matter for evaluation, so no shuffling.
    test = torchvision.datasets.MNIST(root="./data", train=False, download=True,
                                      transform=transform)
    test_loader = DataLoader(test, batch_size=batch_size)

    return train_loader, test_loader


def train(model, epoch, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Relies on the module-level ``optimizer`` (which must hold this model's
    parameters) and the module-level ``use_cuda`` flag.

    Args:
        model: Network to train. Its output is fed to ``F.nll_loss``, so it
            is expected to return log-probabilities.
        epoch: Epoch index; only used in the progress message.
        train_loader: Iterable yielding ``(x, y)`` batches.
    """
    # Enable training behaviour (e.g. dropout).
    model.train()

    # Moving the model to the GPU is loop-invariant: do it once here
    # instead of on every batch.
    if use_cuda:
        model = model.cuda()

    for step, (x, y) in enumerate(train_loader):
        # Inputs must live on the same device as the model.
        if use_cuda:
            x, y = x.cuda(), y.cuda()

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        output = model(x)

        # Negative log-likelihood loss on the model output.
        loss = F.nll_loss(output, y)

        # Backpropagate and update the parameters.
        loss.backward()
        optimizer.step()

        # Report progress every 50 steps; .item() logs a plain Python float.
        if step % 50 == 0:
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss.item()))


def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the accuracy in percent.

    Relies on the module-level ``use_cuda`` flag.

    Args:
        model: Network to evaluate; the predicted class is the argmax of
            its output along dimension 1.
        test_loader: Iterable yielding ``(x, y)`` batches and exposing a
            ``dataset`` attribute whose length is the total sample count.
    """
    # Enable evaluation behaviour (e.g. dropout disabled).
    model.eval()

    # Moving the model to the GPU is loop-invariant: do it once here
    # instead of on every batch.
    if use_cuda:
        model = model.cuda()

    # Running count of correctly classified samples.
    correct = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for x, y in test_loader:
            # Inputs must live on the same device as the model.
            if use_cuda:
                x, y = x.cuda(), y.cuda()

            output = model(x)

            # Predicted class = index of the largest output per sample.
            pred = output.argmax(dim=1, keepdim=True)

            # Count predictions that match the labels.
            correct += pred.eq(y.view_as(pred)).sum().item()

    # Accuracy as a percentage of the whole dataset.
    accuracy = correct / len(test_loader.dataset) * 100

    print("Test Accuracy: {}%".format(accuracy))


def main():
    """Load the data, then train and evaluate the network once per epoch.

    Uses the module-level ``network`` and ``iteration_num``.
    """
    banner = "\n================ epoch: {} ================"
    training_batches, evaluation_batches = get_data()

    for epoch in range(iteration_num):
        print(banner.format(epoch))
        # One pass over the training data, followed by a test-set evaluation.
        train(network, epoch, training_batches)
        test(network, evaluation_batches)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

输出结果:

Model(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
是否使用 GPU 加速: True

================ epoch: 0 ================
Epoch: 0, Step 0, Loss: 2.3131277561187744
Epoch: 0, Step 50, Loss: 1.0419045686721802
Epoch: 0, Step 100, Loss: 0.6259541511535645
Epoch: 0, Step 150, Loss: 0.7194482684135437
Epoch: 0, Step 200, Loss: 0.4020516574382782
Epoch: 0, Step 250, Loss: 0.6890509128570557
Epoch: 0, Step 300, Loss: 0.28660136461257935
Epoch: 0, Step 350, Loss: 0.3277580738067627
Epoch: 0, Step 400, Loss: 0.2750288248062134
Epoch: 0, Step 450, Loss: 0.28428223729133606
Epoch: 0, Step 500, Loss: 0.3514065444469452
Epoch: 0, Step 550, Loss: 0.23386947810649872
Epoch: 0, Step 600, Loss: 0.25338059663772583
Epoch: 0, Step 650, Loss: 0.1743898093700409
Epoch: 0, Step 700, Loss: 0.35752204060554504
Epoch: 0, Step 750, Loss: 0.17575909197330475
Epoch: 0, Step 800, Loss: 0.20604261755943298
Epoch: 0, Step 850, Loss: 0.17389622330665588
Epoch: 0, Step 900, Loss: 0.3188241124153137
Test Accuracy: 96.56%

================ epoch: 1 ================
Epoch: 1, Step 0, Loss: 0.23558208346366882
Epoch: 1, Step 50, Loss: 0.13511177897453308
Epoch: 1, Step 100, Loss: 0.18823786079883575
Epoch: 1, Step 150, Loss: 0.2644936144351959
Epoch: 1, Step 200, Loss: 0.145077645778656
Epoch: 1, Step 250, Loss: 0.30574971437454224
Epoch: 1, Step 300, Loss: 0.2386859953403473
Epoch: 1, Step 350, Loss: 0.08346735686063766
Epoch: 1, Step 400, Loss: 0.10480977594852448
Epoch: 1, Step 450, Loss: 0.07280707359313965
Epoch: 1, Step 500, Loss: 0.20928426086902618
Epoch: 1, Step 550, Loss: 0.20455852150917053
Epoch: 1, Step 600, Loss: 0.10085935145616531
Epoch: 1, Step 650, Loss: 0.13476189970970154
Epoch: 1, Step 700, Loss: 0.19087043404579163
Epoch: 1, Step 750, Loss: 0.0981522724032402
Epoch: 1, Step 800, Loss: 0.1961515098810196
Epoch: 1, Step 850, Loss: 0.041140712797641754
Epoch: 1, Step 900, Loss: 0.250461220741272
Test Accuracy: 98.03%

================ epoch: 2 ================
Epoch: 2, Step 0, Loss: 0.09572553634643555
Epoch: 2, Step 50, Loss: 0.10370486229658127
Epoch: 2, Step 100, Loss: 0.17737184464931488
Epoch: 2, Step 150, Loss: 0.1570713371038437
Epoch: 2, Step 200, Loss: 0.07462178170681
Epoch: 2, Step 250, Loss: 0.18744900822639465
Epoch: 2, Step 300, Loss: 0.09910508990287781
Epoch: 2, Step 350, Loss: 0.08929706364870071
Epoch: 2, Step 400, Loss: 0.07703761011362076
Epoch: 2, Step 450, Loss: 0.10133732110261917
Epoch: 2, Step 500, Loss: 0.1314031481742859
Epoch: 2, Step 550, Loss: 0.10394387692213058
Epoch: 2, Step 600, Loss: 0.11612939089536667
Epoch: 2, Step 650, Loss: 0.17494803667068481
Epoch: 2, Step 700, Loss: 0.11065669357776642
Epoch: 2, Step 750, Loss: 0.061209067702293396
Epoch: 2, Step 800, Loss: 0.14715790748596191
Epoch: 2, Step 850, Loss: 0.03930797800421715
Epoch: 2, Step 900, Loss: 0.18030673265457153
Test Accuracy: 98.46000000000001%

================ epoch: 3 ================
Epoch: 3, Step 0, Loss: 0.09266342222690582
Epoch: 3, Step 50, Loss: 0.0414913073182106
Epoch: 3, Step 100, Loss: 0.2152961939573288
Epoch: 3, Step 150, Loss: 0.12287424504756927
Epoch: 3, Step 200, Loss: 0.13468700647354126
Epoch: 3, Step 250, Loss: 0.11967387050390244
Epoch: 3, Step 300, Loss: 0.11301510035991669
Epoch: 3, Step 350, Loss: 0.037447575479745865
Epoch: 3, Step 400, Loss: 0.04699449613690376
Epoch: 3, Step 450, Loss: 0.05472381412982941
Epoch: 3, Step 500, Loss: 0.09839300811290741
Epoch: 3, Step 550, Loss: 0.07964356243610382
Epoch: 3, Step 600, Loss: 0.08182843774557114
Epoch: 3, Step 650, Loss: 0.05514759197831154
Epoch: 3, Step 700, Loss: 0.13785190880298615
Epoch: 3, Step 750, Loss: 0.062480345368385315
Epoch: 3, Step 800, Loss: 0.120387002825737
Epoch: 3, Step 850, Loss: 0.04458726942539215
Epoch: 3, Step 900, Loss: 0.17119190096855164
Test Accuracy: 98.55000000000001%

================ epoch: 4 ================
Epoch: 4, Step 0, Loss: 0.08094145357608795
Epoch: 4, Step 50, Loss: 0.05615215748548508
Epoch: 4, Step 100, Loss: 0.07766406238079071
Epoch: 4, Step 150, Loss: 0.07915271818637848
Epoch: 4, Step 200, Loss: 0.1301635503768921
Epoch: 4, Step 250, Loss: 0.12118984013795853
Epoch: 4, Step 300, Loss: 0.073218435049057
Epoch: 4, Step 350, Loss: 0.04517696052789688
Epoch: 4, Step 400, Loss: 0.08493026345968246
Epoch: 4, Step 450, Loss: 0.03904269263148308
Epoch: 4, Step 500, Loss: 0.09386837482452393
Epoch: 4, Step 550, Loss: 0.12583576142787933
Epoch: 4, Step 600, Loss: 0.09053893387317657
Epoch: 4, Step 650, Loss: 0.06912104040384293
Epoch: 4, Step 700, Loss: 0.1502612829208374
Epoch: 4, Step 750, Loss: 0.07162325084209442
Epoch: 4, Step 800, Loss: 0.10512275993824005
Epoch: 4, Step 850, Loss: 0.028180215507745743
Epoch: 4, Step 900, Loss: 0.08492615073919296
Test Accuracy: 98.69%

到此这篇关于手把手教你实现PyTorch的MNIST数据集的文章就介绍到这了, 更多相关PyTorch MNIST数据集内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章, 希望大家以后多多支持三水点靠木!

Python 相关文章推荐
简单的抓取淘宝图片的Python爬虫
Dec 25 Python
Python使用pymysql小技巧
Jun 04 Python
python list元素为tuple时的排序方法
Apr 18 Python
使用Python的Django和layim实现即时通讯的方法
May 25 Python
django主动抛出403异常的方法详解
Jan 04 Python
通过python改变图片特定区域的颜色详解
Jul 15 Python
django中使用POST方法获取POST数据
Aug 20 Python
python元组的概念知识点
Nov 19 Python
python 检查数据中是否有缺失值,删除缺失值的方式
Dec 02 Python
python异常处理、自定义异常、断言原理与用法分析
Mar 23 Python
详解python算法常用技巧与内置库
Oct 17 Python
如何基于python实现单目三维重建详解
Jun 25 Python
PyMongo 查询数据的实现
Jun 28 #Python
浅谈哪个Python库才最适合做数据可视化
总结Python变量的相关知识
详解非极大值抑制算法之Python实现
Python实现生活常识解答机器人
Python办公自动化之教你如何用Python将任意文件转为PDF格式
Python移位密码、仿射变换解密实例代码
You might like
PHP会话控制:Session与Cookie详解
2014/09/27 PHP
PHP实现将几张照片拼接到一起的合成图片功能【便于整体打印输出】
2017/11/14 PHP
jquery 入门教程 [翻译] 推荐
2009/08/17 Javascript
使用Jquery搭建最佳用户体验的登录页面之记住密码自动登录功能(含后台代码)
2011/07/10 Javascript
javascript语言结构小记(一)
2011/09/10 Javascript
zShowBox 图片放大展示jquery版 兼容性
2011/09/24 Javascript
从面试题学习Javascript 面向对象(创建对象)
2012/03/30 Javascript
JavaScript中:表达式和语句的区别[译]
2012/09/17 Javascript
解析Javascript小括号“()”的多义性
2013/12/03 Javascript
单击和双击事件的冲突处理示例代码
2014/04/03 Javascript
原生js编写设为首页兼容ie、火狐和谷歌
2014/06/05 Javascript
基于jquery实现的仿优酷图片轮播特效代码
2016/01/13 Javascript
浅析C/C++,Java,PHP,JavaScript,Json数组、对象赋值时最后一个元素后面是否可以带逗号
2016/03/22 Javascript
jQuery插件FusionWidgets实现的AngularGauge图效果示例【附demo源码】
2017/03/23 jQuery
详解升级react-router 4 踩坑指南
2017/08/14 Javascript
页面点击小红心js实现代码
2018/05/26 Javascript
解决vue-cli webpack打包后加载资源的路径问题
2018/09/25 Javascript
VUE实现密码验证与提示功能
2019/10/18 Javascript
Python列表list操作符实例分析【标准类型操作符、切片、连接字符、列表解析、重复操作等】
2017/07/24 Python
Python3实现爬取指定百度贴吧页面并保存页面数据生成本地文档的方法
2018/04/22 Python
详解python读取和输出到txt
2019/03/29 Python
Python面向对象之类和实例用法分析
2019/06/08 Python
Opencv求取连通区域重心实例
2020/06/04 Python
python中random模块详解
2021/03/01 Python
html5简介_动力节点Java学院整理
2017/07/07 HTML / CSS
迪卡侬比利时官网:Decathlon比利时
2019/12/28 全球购物
会计岗位描述
2014/02/22 职场文书
实验室标语
2014/06/21 职场文书
爬山的活动方案
2014/08/16 职场文书
宝葫芦的秘密观后感
2015/06/11 职场文书
领导干部学习十八届五中全会精神心得体会
2016/01/05 职场文书
爱心捐款倡议书:点燃希望,传递温暖
2019/11/04 职场文书
关于golang高并发的实现与注意事项说明
2021/05/08 Golang
你喜欢篮球吗?Python实现篮球游戏
2021/06/11 Python
详解JAVA中的OPTIONAL
2021/06/14 Java/Android
Apache Linkis 中间件架构及快速安装步骤
2022/03/16 Servers