Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python转换摩斯密码示例
Feb 16 Python
用python登录Dr.com思路以及代码分享
Jun 25 Python
12步教你理解Python装饰器
Feb 25 Python
利用Python如何实现数据驱动的接口自动化测试
May 11 Python
tensorflow 加载部分变量的实例讲解
Jul 27 Python
python2与python3的print及字符串格式化小结
Nov 30 Python
使用 Visual Studio Code(VSCode)搭建简单的Python+Django开发环境的方法步骤
Dec 17 Python
对Python3中dict.keys()转换成list类型的方法详解
Feb 03 Python
Python生成随机验证码代码实例解析
Jun 09 Python
详解python的变量缓存机制
Jan 24 Python
提取视频中的音频 Python只需要三行代码!
May 10 Python
一篇文章弄懂Python关键字、标识符和变量
Jul 15 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
让PHP COOKIE立即生效,不用刷新就可以使用
2011/03/09 PHP
PHP数据库处理封装类实例
2016/12/24 PHP
JavaScript基本对象
2007/01/11 Javascript
JavaScript 对象模型 执行模型
2010/10/15 Javascript
jquery聚焦文本框与扩展文本框聚焦方法
2012/10/12 Javascript
getJSON调用后台json数据时函数被调用两次的原因猜想
2013/09/29 Javascript
简述AngularJS相关的一些编程思想
2015/06/23 Javascript
js实现的星星评分功能函数
2015/12/09 Javascript
Javascript字符串常用方法详解
2016/07/21 Javascript
jQuery删除当前节点元素
2016/12/07 Javascript
JavaScript数据结构之栈实例用法
2019/01/18 Javascript
小程序中canvas的drawImage方法参数使用详解
2019/07/04 Javascript
javascript事件监听与事件委托实例详解
2019/08/16 Javascript
关于JS解构的5种有趣用法
2019/09/05 Javascript
layer ui 导入文件之前传入数据的实例
2019/09/23 Javascript
javascript设计模式 ? 策略模式原理与用法实例分析
2020/04/21 Javascript
js瀑布流布局的实现
2020/06/28 Javascript
Vue全局使用less样式,组件使用全局样式文件中定义的变量操作
2020/10/21 Javascript
[01:03:36]DOTA2-DPC中国联赛 正赛 VG vs Magma BO3 第二场 1月26日
2021/03/11 DOTA
Django开发中的日志输出的方法
2018/07/02 Python
ZABBIX3.2使用python脚本实现监控报表的方法
2019/07/02 Python
Python 用三行代码提取PDF表格数据
2019/10/13 Python
TensorFlow使用Graph的基本操作的实现
2020/04/22 Python
VScode连接远程服务器上的jupyter notebook的实现
2020/04/23 Python
Python爬虫scrapy框架Cookie池(微博Cookie池)的使用
2021/01/13 Python
支持IE8的纯css3开发的响应式设计动画菜单教程
2014/11/05 HTML / CSS
台湾时尚彩瞳专门店:imeime
2019/08/16 全球购物
是否可以从一个static方法内部发出对非static方法的调用?
2014/08/18 面试题
挂科检讨书范文
2014/02/20 职场文书
机关单位动员会主持词
2014/03/20 职场文书
一份关于丢失公司财物的检讨书
2014/09/19 职场文书
不同意离婚代理词
2015/05/23 职场文书
幼儿体育课教学反思
2016/02/16 职场文书
MySQL infobright的安装步骤
2021/04/07 MySQL
常用的MongoDB查询语句的示例代码
2021/07/25 MongoDB
windows系统安装配置nginx环境
2022/06/28 Servers