Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现图片拼接的代码
Jul 02 Python
用pandas中的DataFrame时选取行或列的方法
Jul 11 Python
对pandas的层次索引与取值的新方法详解
Nov 06 Python
不归路系列:Python入门之旅-一定要注意缩进!!!(推荐)
Apr 16 Python
PyTorch笔记之scatter()函数的使用
Feb 12 Python
Python接口测试get请求过程详解
Feb 28 Python
PyQt5实现登录页面
May 30 Python
Python一些基本的图像操作和处理总结
Jun 23 Python
Python集合的基础操作
Nov 01 Python
python数字类型和占位符详情
Mar 13 Python
Pillow图像处理库安装及使用
Apr 12 Python
Python数据可视化之Seaborn的安装及使用
Apr 19 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
php+mysqli实现批量替换数据库表前缀的方法
2014/12/29 PHP
教你识别简单的免查杀PHP后门
2015/09/13 PHP
php验证手机号码
2015/11/11 PHP
常用PHP数组排序函数归纳
2016/08/08 PHP
PHP中的输出echo、print、printf、sprintf、print_r和var_dump的示例代码
2020/12/01 PHP
javascript 文档的编码问题解决
2009/03/01 Javascript
通过身份证号得到出生日期和性别的js代码
2009/11/23 Javascript
详谈 Jquery Ajax异步处理Json数据.
2011/09/09 Javascript
HTML中的setCapture和releaseCapture使用介绍
2012/03/21 Javascript
单击按钮显示隐藏子菜单经典案例
2013/01/04 Javascript
JQuery AJAX 中文乱码问题解决
2013/06/05 Javascript
一个非常全面的javascript URL解析函数和分段URL解析方法
2014/04/12 Javascript
js拼接html注意问题示例探讨
2014/07/14 Javascript
基于AngularJS实现页面滚动到底自动加载数据的功能
2015/10/16 Javascript
很酷的星级评分系统原生JS实现
2016/08/25 Javascript
Boostrap实现的登录界面实例代码
2016/10/09 Javascript
浅谈React 属性和状态的一些总结
2016/11/21 Javascript
完美解决IE不支持Data.parse()的问题
2016/11/24 Javascript
微信小程序左右滑动的实现代码
2017/12/15 Javascript
js中apply和Math.max()函数的问题及区别介绍
2018/03/27 Javascript
layer弹出层自适应高度,垂直水平居中的实现
2019/09/16 Javascript
2019年度web前端面试题总结(主要为Vue面试题)
2020/01/12 Javascript
JavaScript canvas实现跟随鼠标事件
2020/02/10 Javascript
小程序如何定位所在城市及发起周边搜索
2020/02/11 Javascript
Python 中Pickle库的使用详解
2018/02/24 Python
python微信跳一跳系列之自动计算跳一跳距离
2018/02/26 Python
pycharm配置git(图文教程)
2019/08/16 Python
结合CSS3的新特性来总结垂直居中的实现方法
2016/05/30 HTML / CSS
CSS3自定义滚动条样式 ::webkit-scrollbar的示例代码详解
2020/06/01 HTML / CSS
HTML5中语义化 b 和 i 标签
2008/10/17 HTML / CSS
美国领先的医疗警报服务:Philips Lifeline
2018/03/12 全球购物
马来西亚户外装备商店:PTT Outdoor
2019/07/13 全球购物
机械专业个人求职自荐信格式
2013/09/21 职场文书
2015年高校图书馆工作总结
2015/04/30 职场文书
中秋节作文(五年级)之关于月亮
2019/09/11 职场文书
服务器nginx权限被拒绝解决案例
2022/09/23 Servers