Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
在树莓派2或树莓派B+上安装Python和OpenCV的教程
Mar 30 Python
Python写的一个定时重跑获取数据库数据
Dec 28 Python
python使用 HTMLTestRunner.py生成测试报告
Oct 20 Python
Python3调用微信企业号API发送文本消息代码示例
Nov 10 Python
Python中getpass模块无回显输入源码解析
Jan 11 Python
python获取微信企业号打卡数据并生成windows计划任务
Apr 30 Python
django3.02模板中的超链接配置实例代码
Feb 04 Python
Java byte数组操纵方式代码实例解析
Jul 22 Python
Python如何输出百分比
Jul 31 Python
利用Python如何制作贪吃蛇及AI版贪吃蛇详解
Aug 24 Python
Python Django路径配置实现过程解析
Nov 05 Python
python实现大文本文件分割成多个小文件
Apr 20 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
全国FM电台频率大全 - 29 青海省
2020/03/11 无线电
通俗易懂的php防注入代码
2010/04/07 PHP
PHP 魔术函数使用说明
2010/05/14 PHP
ci检测是ajax还是页面post提交数据的方法
2014/11/10 PHP
简单谈谈php中ob_flush和flush的区别
2014/11/27 PHP
理解JavaScript变量作用域更轻松
2009/10/25 Javascript
EXT窗口Window及对话框MessageBox
2011/01/27 Javascript
JS实现切换标签页效果实例代码
2013/11/01 Javascript
javascript顺序加载图片的方法
2015/07/18 Javascript
基于JavaScript实现右键菜单和拖拽功能
2016/11/28 Javascript
解决angularjs service中依赖注入$scope报错的问题
2018/10/02 Javascript
Javascript中绑定click事件的四种方式介绍
2018/10/26 Javascript
vue2 v-model/v-text 中使用过滤器的方法示例
2019/05/09 Javascript
js中关于Blob对象的介绍与使用
2019/11/29 Javascript
基于Python和Scikit-Learn的机器学习探索
2017/10/16 Python
基于Python List的赋值方法
2018/06/23 Python
django中forms组件的使用与注意
2019/07/08 Python
python单向循环链表原理与实现方法示例
2019/12/03 Python
keras导入weights方式
2020/06/12 Python
最简单的matplotlib安装教程(小白)
2020/07/28 Python
CSS3常用的几种颜色渐变模式总结
2016/11/18 HTML / CSS
水产养殖学应届生求职信
2013/09/29 职场文书
护理实习自我鉴定
2013/12/14 职场文书
政法大学毕业生自荐信范文
2014/01/01 职场文书
教师绩效考核方案
2014/01/21 职场文书
关于期中考试的反思
2014/02/02 职场文书
节约用水标语
2014/06/11 职场文书
2014年煤矿安全工作总结
2014/12/04 职场文书
班主任高考寄语
2015/02/26 职场文书
总账会计岗位职责
2015/04/02 职场文书
毕业证明模板
2015/06/19 职场文书
演讲比赛通讯稿
2015/07/18 职场文书
2016年猴年新春致辞
2015/08/01 职场文书
基于Redis zSet实现滑动窗口对短信进行防刷限流的问题
2022/02/12 Redis
Tomcat执行startup.bat出现闪退的原因及解决办法
2022/04/20 Servers
Java Spring Boot 正确读取配置文件中的属性的值
2022/04/20 Java/Android