Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python生成随机数的方法
Jan 14 Python
Python学习笔记之常用函数及说明
May 23 Python
python3 pillow生成简单验证码图片的示例
Sep 19 Python
python 字符串和整数的转换方法
Jun 25 Python
python pytest进阶之fixture详解
Jun 27 Python
关于PyTorch源码解读之torchvision.models
Aug 17 Python
python遍历文件目录、批量处理同类文件
Aug 31 Python
Python队列、进程间通信、线程案例
Oct 25 Python
解决python replace函数替换无效问题
Jan 18 Python
解决python父线程关闭后子线程不关闭问题
Apr 25 Python
Python虚拟环境venv用法详解
May 25 Python
Python3中的tuple函数知识点讲解
Jan 03 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
PHP 模板高级篇总结
2006/12/21 PHP
用PHP ob_start()控制浏览器cache、生成html实现代码
2010/02/16 PHP
浅析php中三个等号(===)和两个等号(==)的区别
2013/08/06 PHP
php用ini_get获取php.ini里变量值的方法
2015/03/04 PHP
PHP版本的选择5.2.17 5.3.27 5.3.28 5.4 5.5兼容性问题分析
2016/04/04 PHP
php 判断字符串编码是utf-8 或gb2312实例
2016/11/01 PHP
php获取访问者浏览页面的浏览器类型
2017/01/23 PHP
PHP实现对数组分页处理实例详解
2017/02/07 PHP
PHP回调函数概念与用法实例分析
2017/11/03 PHP
[原创]保存的js无法执行的解决办法
2007/02/25 Javascript
jquery自定义下拉列表示例
2014/04/25 Javascript
调试JavaScript中正则表达式中遇到的问题
2015/01/27 Javascript
JS IOS/iPhone的Safari浏览器不兼容Javascript中的Date()问题如何解决
2016/11/11 Javascript
Bootstrap 设置datetimepicker在屏幕上面弹出设置方法
2017/03/21 Javascript
微信小程序实现滑动删除效果
2017/05/19 Javascript
Vue实现侧边菜单栏手风琴效果实例代码
2018/05/31 Javascript
vue 巧用过渡效果(小结)
2018/09/22 Javascript
分享Python文本生成二维码实例
2016/01/06 Python
django实现用户登陆功能详解
2017/12/11 Python
python实现k-means聚类算法
2018/02/23 Python
python实现最长公共子序列
2018/05/22 Python
分享vim python缩进等一些配置
2018/07/02 Python
Python定义二叉树及4种遍历方法实例详解
2018/07/05 Python
Python HTML解析器BeautifulSoup用法实例详解【爬虫解析器】
2019/04/05 Python
python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法
2019/06/10 Python
python3 正则表达式基础廖雪峰
2020/03/25 Python
python topk()函数求最大和最小值实例
2020/04/02 Python
使用keras时input_shape的维度表示问题说明
2020/06/29 Python
JD Sports意大利:英国篮球和运动时尚的领导者
2017/10/29 全球购物
英国文具、办公用品和科技商店:Ryman
2018/09/27 全球购物
工商管理实习自我鉴定
2013/09/28 职场文书
初一科学教学反思
2014/01/27 职场文书
数据保密承诺书
2014/06/03 职场文书
欢度春节标语
2014/07/01 职场文书
党员教师四风自我剖析材料
2014/09/30 职场文书
vue使用Google Recaptcha验证的实现示例
2021/08/23 Vue.js