Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Windows下python2.7.8安装图文教程
May 26 Python
python绘制多个曲线的折线图
Mar 23 Python
为什么str(float)在Python 3中比Python 2返回更多的数字
Oct 16 Python
python 用下标截取字符串的实例
Dec 25 Python
Python实现的拉格朗日插值法示例
Jan 08 Python
python生成随机红包的实例写法
Sep 02 Python
Python参数传递机制传值和传引用原理详解
May 22 Python
tensorflow实现残差网络方式(mnist数据集)
May 26 Python
为什么称python为胶水语言
Jun 16 Python
查找适用于matplotlib的中文字体名称与实际文件名对应关系的方法
Jan 05 Python
python 实现定时任务的四种方式
Apr 01 Python
Python使用psutil库对系统数据进行采集监控的方法
Aug 23 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
php 保留小数点
2009/04/21 PHP
新手学习PHP的一些基础知识分享
2011/07/27 PHP
php和jquery实现地图区域数据统计展示数据示例
2014/02/12 PHP
py文件转exe时包含paramiko模块出错解决方法
2016/08/12 PHP
PHP基于mssql扩展远程连接MSSQL的简单实现方法
2016/10/08 PHP
jquery uaMatch源代码
2011/02/14 Javascript
js 通用订单代码
2013/12/23 Javascript
原生javascript实现的一个简单动画效果
2016/03/30 Javascript
JavaScript中获取HTML元素值的三种方法
2016/06/20 Javascript
vue.js中$watch的用法示例
2016/10/04 Javascript
用Nodejs搭建服务器访问html、css、JS等静态资源文件
2017/04/28 NodeJs
值得分享和收藏的xmlplus组件学习教程
2017/05/05 Javascript
JavaScript设计模式之原型模式分析【ES5与ES6】
2018/07/26 Javascript
Angular4.x Event (DOM事件和自定义事件详解)
2018/10/09 Javascript
vue使用v-if v-show页面闪烁,div闪现的解决方法
2018/10/12 Javascript
微信小程序BindTap快速连续点击目标页面跳转多次问题处理
2019/04/08 Javascript
[33:33]完美世界DOTA2联赛PWL S2 FTD.C vs SZ 第二场 11.27
2020/11/30 DOTA
详解基于django实现的webssh简单例子
2018/07/17 Python
对Python正则匹配IP、Url、Mail的方法详解
2018/12/25 Python
python实现socket+threading处理多连接的方法
2019/07/23 Python
Python新手学习raise用法
2020/06/03 Python
Python3+Flask安装使用教程详解
2021/02/16 Python
前后端结合实现amazeUI分页效果
2020/08/21 HTML / CSS
美国高档帽子网上商店:Hats.com
2018/08/09 全球购物
网上卖盒饭创业计划书
2014/01/26 职场文书
简历的自我评价范文
2014/02/04 职场文书
小学社团活动总结
2014/06/27 职场文书
施工安全汇报材料
2014/08/17 职场文书
学校四风对照检查材料
2014/08/28 职场文书
正规欠条模板
2015/07/03 职场文书
2016领导干部廉洁自律心得体会
2016/01/13 职场文书
CSS的class与id常用的命名规则
2021/05/18 HTML / CSS
Python中的pprint模块
2021/11/27 Python
JavaCV实现照片马赛克效果
2022/01/22 Java/Android
使用ICOM IC-R9500接收机同时测评十台收音机中波接收性能
2022/05/10 无线电
Nginx如何获取自定义请求header头和URL参数详解
2022/07/23 Servers