Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python实现简单的多任务mysql转xml的方法
Feb 08 Python
简单谈谈python中的语句和语法
Aug 10 Python
python学习必备知识汇总
Sep 08 Python
详解Python中的Numpy、SciPy、MatPlotLib安装与配置
Nov 17 Python
实例讲解Python脚本成为Windows中运行的exe文件
Jan 24 Python
python3利用Socket实现通信的方法示例
May 06 Python
PyCharm搭建Spark开发环境实现第一个pyspark程序
Jun 13 Python
python3获取url文件大小示例代码
Sep 18 Python
解析Tensorflow之MNIST的使用
Jun 30 Python
python 如何调用远程接口
Sep 11 Python
Python中Cookies导出某站用户数据的方法
May 17 Python
解决Tkinter中button按钮未按却主动执行command函数的问题
May 23 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
便携利器 — TECSUN PL-365简评
2021/03/02 无线电
PHP&MYSQL服务器配置说明
2006/10/09 PHP
域名查询代码公布
2006/10/09 PHP
对比PHP对MySQL的缓冲查询和无缓冲查询
2016/07/01 PHP
使两个iframe的高度与内容自适应,且相等
2006/11/20 Javascript
鼠标经过的文本框textbox变色
2009/05/21 Javascript
jQuery ui1.7 dialog只能弹出一次问题
2009/08/27 Javascript
javascript数字格式化通用类 accounting.js使用
2012/08/24 Javascript
Js 时间函数getYear()的使用问题探讨
2013/04/01 Javascript
jquery获取对象的方法足以应付常见的各种类型的对象
2014/05/14 Javascript
jQuery结合HTML5制作的爱心树表白动画
2015/02/01 Javascript
Jquery实现顶部弹出框特效
2015/08/08 Javascript
Bootstrap每天必学之工具提示(Tooltip)插件
2016/04/26 Javascript
jquery表单验证实例仿Toast提示效果
2017/03/03 Javascript
vue拦截器Vue.http.interceptors.push使用详解
2017/04/22 Javascript
VUE中v-on:click事件中获取当前dom元素的代码
2018/08/22 Javascript
vue环形进度条组件实例应用
2018/10/10 Javascript
简述vue-cli中chainWebpack的使用方法
2019/07/30 Javascript
[43:32]2014 DOTA2华西杯精英邀请赛 5 25 LGD VS NewBee第一场
2014/05/26 DOTA
[01:09:19]DOTA2-DPC中国联赛 正赛 VG vs Aster BO3 第二场 2月28日
2021/03/11 DOTA
对python中raw_input()和input()的用法详解
2018/04/22 Python
tensorflow自定义激活函数实例
2020/02/04 Python
python爬取代理IP并进行有效的IP测试实现
2020/10/09 Python
css3中transition属性详解
2014/09/02 HTML / CSS
英国最大的海报商店:GB Posters
2018/03/20 全球购物
武汉某公司的C#笔试题面试题
2015/12/25 面试题
主题酒店策划书
2014/01/28 职场文书
硕士研究生求职自荐信范文
2014/03/11 职场文书
2014年国庆节演讲稿精选范文1500字
2014/09/25 职场文书
2015入党自荐书范文
2015/03/05 职场文书
力克胡哲观后感
2015/06/10 职场文书
反四风问题学习心得体会
2016/01/22 职场文书
医生行业员工的辞职信
2019/06/24 职场文书
2019消防宣传标语!
2019/07/10 职场文书
浅谈Golang 嵌套 interface 的赋值问题
2021/04/29 Golang
索尼ICF-36收音机评测
2022/04/30 无线电