Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现分割文件及合并文件的方法
Jul 10 Python
python实现自主查询实时天气
Jun 22 Python
Python3实现获取图片文字里中文的方法分析
Dec 13 Python
对python当中不在本路径的py文件的引用详解
Dec 15 Python
对pandas的算术运算和数据对齐实例详解
Dec 22 Python
python sort、sort_index方法代码实例
Mar 28 Python
Python学习笔记之变量、自定义函数用法示例
May 28 Python
基于python的socket实现单机五子棋到双人对战
Mar 24 Python
Python解决pip install时出现的Could not fetch URL问题
Aug 01 Python
解决Pytorch 训练与测试时爆显存(out of memory)的问题
Aug 20 Python
python中lower函数实现方法及用法讲解
Dec 23 Python
python基础之while循环语句的使用
Apr 20 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
PHP常用函数小技巧
2008/09/11 PHP
PHP中的string类型使用说明
2010/07/27 PHP
phpmyadmin提示The mbstring extension is missing的解决方法
2014/12/17 PHP
PHP中通过getopt解析GNU C风格命令行选项
2019/11/18 PHP
JSON.parse 解析字符串出错的解决方法
2010/07/08 Javascript
nodejs win7下安装方法
2012/05/24 NodeJs
JS对select控件option选项的增删改查示例代码
2013/10/21 Javascript
js hover 定时器(实例代码)
2013/11/12 Javascript
js实现感应鼠标图片透明度变化的方法
2015/02/20 Javascript
jquery读取xml文件实现省市县三级联动的方法
2015/05/29 Javascript
javascript动态生成树形菜单的方法
2015/11/14 Javascript
JSONP原理及简单实现
2016/06/08 Javascript
AngularJS中比较两个数组是否相同
2016/08/24 Javascript
Javascript 实现放大镜效果实例详解
2016/12/03 Javascript
js操作浏览器的参数方法
2017/01/21 Javascript
JavaScript拖动层Div代码
2017/03/01 Javascript
前端构建工具之gulp的配置与搭建详解
2017/06/12 Javascript
vue2.0组件之间传值、通信的多种方式(干货)
2018/02/10 Javascript
nodejs 生成和导出 word的实例代码
2018/07/31 NodeJs
详细分析React 表单与事件
2020/07/08 Javascript
vue.js+element 默认提示中英文操作
2020/11/11 Javascript
python网络编程学习笔记(八):XML生成与解析(DOM、ElementTree)
2014/06/09 Python
利用python numpy+matplotlib绘制股票k线图的方法
2019/06/26 Python
python爬虫爬取幽默笑话网站
2019/10/24 Python
django 简单实现登录验证给你
2019/11/06 Python
简单了解Python write writelines区别
2020/02/27 Python
文件上传服务器-jupyter 中python解压及压缩方式
2020/04/22 Python
CSS3制作ajax loader icon实现思路及代码
2013/08/25 HTML / CSS
Html5 实现微信分享及自定义内容的流程
2019/08/20 HTML / CSS
FILA斐乐中国官方商城:意大利运动品牌
2017/01/25 全球购物
波兰家具和室内装饰品购物网站:Vivre
2018/04/10 全球购物
如何从一个文件档案的尾端新增记录
2016/12/02 面试题
机电专业体育教师求职信
2013/09/21 职场文书
庆祝新中国成立65周年“向国旗敬礼”网上签名寄语
2014/09/27 职场文书
《自然之道》读后感3篇
2019/12/17 职场文书
利用Python多线程实现图片下载器
2022/03/25 Python