Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
关于Python元祖,列表,字典,集合的比较
Jan 06 Python
使用Python实现windows下的抓包与解析
Jan 15 Python
windows下安装python的C扩展编译环境(解决Unable to find vcvarsall.bat)
Feb 21 Python
python爱心表白 每天都是浪漫七夕!
Aug 18 Python
解决python opencv无法显示图片的问题
Oct 28 Python
Python之使用adb shell命令启动应用的方法详解
Jan 07 Python
解决使用PyCharm时无法启动控制台的问题
Jan 19 Python
详解Python3 基本数据类型
Apr 19 Python
python3 selenium自动化 frame表单嵌套的切换方法
Aug 23 Python
Python PyInstaller库基本使用方法分析
Dec 12 Python
python使用gdal对shp读取,新建和更新的实例
Mar 10 Python
Python之Matplotlib绘制热力图和面积图
Apr 13 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
基于mysql的bbs设计(五)
2006/10/09 PHP
PHP 和 XML: 使用expat函数(一)
2006/10/09 PHP
php获取一个变量的名字的方法
2014/09/05 PHP
ThinkPHP中关联查询实例
2014/12/02 PHP
详解PHP的Yii框架中组件行为的属性注入和方法注入
2016/03/18 PHP
PHPExcel 修改已存在Excel的方法
2018/05/03 PHP
javascript 动态设置已知select的option的value值的代码
2009/12/16 Javascript
jQuery学习笔记之DOM对象和jQuery对象
2010/12/22 Javascript
用innerhtml提高页面打开速度的方法
2013/08/02 Javascript
原生JavaScript实现动态省市县三级联动下拉框菜单实例代码
2016/02/03 Javascript
NodeJS整合银联网关支付(DEMO)
2016/11/09 NodeJs
Vue组件和Route的生命周期实例详解
2018/02/10 Javascript
vue 点击按钮增加一行的方法
2018/09/07 Javascript
微信小程序使用component自定义toast弹窗效果
2018/11/27 Javascript
爬虫利器Puppeteer实战
2019/01/09 Javascript
手把手教你使用TypeScript开发Node.js应用
2019/05/06 Javascript
通过js给网页加上水印背景实例
2019/06/17 Javascript
JS常用排序方法实例代码解析
2020/03/03 Javascript
前端开发基础javaScript的六大作用
2020/08/06 Javascript
[02:43]2018DOTA2亚洲邀请赛主赛事首日TOP5
2018/04/04 DOTA
[41:13]完美世界DOTA2联赛PWL S2 Forest vs Rebirth 第一场 11.20
2020/11/20 DOTA
python中关于日期时间处理的问答集锦
2013/03/08 Python
python使用Flask框架获取用户IP地址的方法
2015/03/21 Python
Python3 获取一大段文本之间两个关键字之间的内容方法
2018/10/11 Python
python高斯分布概率密度函数的使用详解
2019/07/10 Python
Python求正态分布曲线下面积实例
2019/11/20 Python
Python创建数字列表的示例
2019/11/28 Python
TensorFLow 变量命名空间实例
2020/02/11 Python
Visual Studio Code搭建django项目的方法步骤
2020/09/17 Python
全球领先的鞋类零售商:The Walking Company
2016/07/21 全球购物
英国莱斯特松木橡木家具网上商店:Choice Furniture Superstore
2019/07/05 全球购物
DBA的职责都有哪些
2012/05/16 面试题
社区文化建设方案
2014/05/02 职场文书
个人承诺书格式范文
2015/04/29 职场文书
导游词之凤凰古城
2019/10/22 职场文书
MongoDB数据库之添删改查
2022/04/26 MongoDB