Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python使用心得之获得github代码库列表
Jun 25 Python
Python实现的一个简单LRU cache
Sep 26 Python
python使用Flask框架获取用户IP地址的方法
Mar 21 Python
在Python中关于中文编码问题的处理建议
Apr 08 Python
初步探究Python程序的执行原理
Apr 11 Python
Python爬取数据保存为Json格式的代码示例
Apr 09 Python
Django实现文件上传下载
Oct 06 Python
Django 再谈一谈json序列化
Mar 16 Python
python实现交并比IOU教程
Apr 16 Python
Python自动登录QQ的实现示例
Aug 28 Python
python 进制转换 int、bin、oct、hex的原理
Jan 13 Python
使用Python快速打开一个百万行级别的超大Excel文件的方法
Mar 02 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
PHP中使用Imagick操作PSD文件实例
2015/01/26 PHP
Ajax提交表单时验证码自动验证 php后端验证码检测
2016/07/20 PHP
PHP合并数组的2种方法小结
2016/11/24 PHP
yii2 url重写并隐藏index.php方法
2018/12/10 PHP
在phpstudy集成环境下的nginx服务器下配置url重写
2019/12/02 PHP
Alliance vs Liquid BO3 第二场2.13
2021/03/10 DOTA
List Information About the Binary Files Used by an Application
2007/06/11 Javascript
js 表单验证方法(实用)
2009/04/28 Javascript
动感效果的TAB选项卡jquery 插件
2011/07/09 Javascript
JavaScript控制各种浏览器全屏模式的方法、属性和事件介绍
2014/04/03 Javascript
jQuery scrollFix滚动定位插件
2015/04/01 Javascript
JavaScript中的Function函数
2015/08/27 Javascript
手机软键盘弹出时影响布局的解决方法
2016/12/15 Javascript
js CSS3实现卡牌旋转切换效果
2017/07/04 Javascript
前端 javascript 实现文件下载的示例
2020/11/24 Javascript
[03:39]DOTA2英雄梦之声_第05期_幽鬼
2014/06/23 DOTA
python创建一个最简单http webserver服务器的方法
2015/05/08 Python
python进程和线程用法知识点总结
2019/05/28 Python
pyqt5之将textBrowser的内容写入txt文档的方法
2019/06/21 Python
Python实现微信好友的数据分析
2019/12/16 Python
Python编程快速上手——PDF文件操作案例分析
2020/02/28 Python
python使用openpyxl操作excel的方法步骤
2020/05/28 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
2020/07/07 Python
python Tornado框架的使用示例
2020/10/19 Python
python实现一个简单RPC框架的示例
2020/10/28 Python
Scrapy-Redis之RedisSpider与RedisCrawlSpider详解
2020/11/18 Python
运动会表扬稿大全
2014/01/16 职场文书
法律六进活动方案
2014/03/13 职场文书
小学生母亲节演讲稿
2014/05/07 职场文书
安全伴我行演讲稿
2014/09/04 职场文书
镇政府副镇长群众路线专题民主生活会对照检查材料
2014/09/19 职场文书
2014年度工作总结报告
2014/12/15 职场文书
出国留学单位推荐信
2015/03/26 职场文书
2019年大学生暑期社会实践调查报告模板
2019/11/07 职场文书
Python爬虫基础之简单说一下scrapy的框架结构
2021/06/26 Python
修改并编译golang源码的操作步骤
2021/07/25 Golang