Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
在Linux系统上安装Python的Scrapy框架的教程
Jun 11 Python
Python实现简单拆分PDF文件的方法
Jul 30 Python
Python 爬虫模拟登陆知乎
Sep 23 Python
Python排序算法实例代码
Aug 10 Python
python读取目录下最新的文件夹方法
Dec 24 Python
python矩阵/字典实现最短路径算法
Jan 17 Python
python中嵌套函数的实操步骤
Feb 27 Python
Python 获取ftp服务器文件时间的方法
Jul 02 Python
将python运行结果保存至本地文件中的示例讲解
Jul 11 Python
Python面向对象封装操作案例详解 II
Jan 02 Python
通过实例学习Python Excel操作
Jan 06 Python
python爬虫开发之Beautiful Soup模块从安装到详细使用方法与实例
Mar 09 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
php使用codebase生成随机数
2014/03/25 PHP
phpcms手机内容页面添加上一篇和下一篇
2015/06/05 PHP
PHP基于简单递归函数求一个数阶乘的方法示例
2017/04/26 PHP
PHP实现的策略模式简单示例
2017/08/25 PHP
JavaScript事件列表解说
2006/12/22 Javascript
input 高级限制级用法
2009/03/26 Javascript
JavaScript与Image加载事件(onload)、加载状态(complete)
2011/02/14 Javascript
JS代码优化技巧之通俗版(减少js体积)
2011/12/23 Javascript
JS弹出窗口代码大全(详细整理)
2012/12/21 Javascript
JavaScript极简入门教程(二):对象和函数
2014/10/25 Javascript
原生JS实现响应式瀑布流布局
2015/04/02 Javascript
Eclipse引入jquery报错如何解决
2015/12/01 Javascript
分享几种比较简单实用的JavaScript tabel切换
2015/12/31 Javascript
JS获取IMG图片高宽的简单实例
2016/05/17 Javascript
Jquery EasyUI $.Parser
2017/06/02 jQuery
JS仿QQ好友列表展开、收缩功能(第一篇)
2017/07/07 Javascript
BootStrap点击保存后实现模态框自动关闭的思路(模态框)
2017/09/26 Javascript
Vue项目安装插件并保存
2019/01/28 Javascript
vue项目接口管理,所有接口都在apis文件夹中统一管理操作
2020/08/13 Javascript
微信小程序换肤功能实现代码(思路详解)
2020/08/25 Javascript
一起来了解一下JavaScript的预编译(小结)
2021/03/01 Javascript
python3实现zabbix告警推送钉钉的示例
2019/02/20 Python
python之pymysql模块简单应用示例代码
2019/12/16 Python
Python try except异常捕获机制原理解析
2020/04/18 Python
Python爬虫如何应对Cloudflare邮箱加密
2020/06/24 Python
HTML5+CSS3实现机器猫
2016/10/17 HTML / CSS
Fossil美国官网:Fossil手表、手袋、珠宝及配件
2017/02/01 全球购物
美国高档帽子网上商店:Hats.com
2018/08/09 全球购物
如何设定的weblogic的热启动模式(开发模式)与产品发布模式
2012/09/08 面试题
中学生操行评语
2014/04/24 职场文书
迎新晚会策划方案
2014/06/13 职场文书
护士节活动总结
2014/08/29 职场文书
群众路线对照检查材料思想汇报怎么写
2014/09/18 职场文书
弘扬焦裕禄精神践行三严三实心得体会
2014/10/13 职场文书
关于拾金不昧的感谢信(五篇)
2019/10/18 职场文书
在Python中如何使用yield
2021/06/07 Python