Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python httplib,smtplib使用方法
Sep 06 Python
python 判断自定义对象类型
Mar 21 Python
python中使用sys模板和logging模块获取行号和函数名的方法
Apr 15 Python
Python中单例模式总结
Feb 20 Python
Python将一个CSV文件里的数据追加到另一个CSV文件的方法
Jul 04 Python
Pandas删除数据的几种情况(小结)
Jun 21 Python
Python将string转换到float的实例方法
Jul 29 Python
用django设置session过期时间的方法解析
Aug 05 Python
python sqlite的Row对象操作示例
Sep 11 Python
学习Python列表的基础知识汇总
Mar 10 Python
python 使用cx-freeze打包程序的实现
Mar 14 Python
python函数超时自动退出的实操方法
Dec 28 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
php实现utf-8转unicode函数分享
2015/01/06 PHP
PHP将进程作为守护进程的方法
2015/03/19 PHP
php 三元运算符实例详细介绍
2016/12/15 PHP
laravel批量生成假数据的方法
2019/10/09 PHP
有趣的JavaScript数组长度问题代码说明
2011/01/20 Javascript
javascript将数组插入到另一个数组中的代码
2013/01/10 Javascript
JQuery页面图片切换和新闻列表滚动效果的具体实现
2013/09/26 Javascript
jquery中get和post的简单实例
2014/02/04 Javascript
php+js实现倒计时功能
2014/06/02 Javascript
JavaScript字符串对象的concat方法实例(用于连接两个或多个字符串)
2014/10/16 Javascript
基于JavaScript实现熔岩灯效果导航菜单
2017/01/04 Javascript
解决jquery appaend元素中id绑定事件失效的问题
2017/09/12 jQuery
vue2.0$nextTick监听数据渲染完成之后的回调函数方法
2018/09/11 Javascript
JavaScript对JSON数组简单排序操作示例
2019/01/31 Javascript
jQuery 选择器用法基础入门示例
2020/01/04 jQuery
[49:12]完美世界DOTA2联赛PWL S2 Magma vs GXR 第二场 11.29
2020/12/02 DOTA
Python linecache.getline()读取文件中特定一行的脚本
2008/09/06 Python
在Python中通过threading模块定义和调用线程的方法
2016/07/12 Python
基于Django框架利用Ajax实现点赞功能实例代码
2018/08/19 Python
python实现简单加密解密机制
2019/03/19 Python
python3.7 openpyxl 删除指定一列或者一行的代码
2019/10/08 Python
使用批处理脚本自动生成并上传NuGet包(操作方法)
2019/11/19 Python
美的官方商城:Midea
2016/09/14 全球购物
罗马尼亚在线杂货店:Pilulka.ro
2019/09/28 全球购物
Laravel的加密解密与哈希实例讲解
2021/03/24 PHP
《最佳路径》教学反思
2014/04/13 职场文书
书法社团活动总结
2015/05/07 职场文书
2016年主题党日活动总结
2016/04/05 职场文书
一文读懂go中semaphore(信号量)源码
2021/04/03 Golang
Django如何创作一个简单的最小程序
2021/05/12 Python
MySQL 8.0 Online DDL快速加列的相关总结
2021/06/02 MySQL
Python中的datetime包与time包包和模块详情
2022/02/28 Python
中国十大神话动漫电影排行榜 哪吒登顶 白蛇缘起排第七
2022/03/21 国漫
Android Flutter实现图片滑动切换效果
2022/04/07 Java/Android
Tomcat执行startup.bat出现闪退的原因及解决办法
2022/04/20 Servers
vue实现省市区联动 element-china-area-data插件
2022/04/22 Vue.js