pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python支持断点续传的多线程下载示例
Jan 16 Python
Python的函数的一些高阶特性
Apr 27 Python
Python利用前序和中序遍历结果重建二叉树的方法
Apr 27 Python
浅析python中的分片与截断序列
Aug 09 Python
对numpy 数组和矩阵的乘法的进一步理解
Apr 04 Python
Python Django的安装配置教程图文详解
Jul 17 Python
Laravel框架表单验证格式化输出的方法
Sep 25 Python
Pytorch 实现sobel算子的卷积操作详解
Jan 10 Python
Python 常用日期处理 -- calendar 与 dateutil 模块的使用
Sep 02 Python
python获取linux系统信息的三种方法
Oct 14 Python
pytorch 一行代码查看网络参数总量的实现
May 12 Python
Python如何解决secure_filename对中文不支持问题
Jul 16 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
smarty模板中拼接字符串的方法
2014/02/14 PHP
php生成数字字母的验证码图片
2015/07/14 PHP
用js计算页面执行时间的函数
2006/12/07 Javascript
详解JavaScript的Date对象(制作简易钟表)
2020/04/07 Javascript
JavaScript中style.left与offsetLeft的使用及区别详解
2016/06/08 Javascript
js css自定义分页效果
2017/02/24 Javascript
JavaScript设计模式之单例模式详解
2017/06/09 Javascript
JS使用数组实现的队列功能示例
2019/03/04 Javascript
浅谈express.js框架中间件(middleware)
2019/04/07 Javascript
使用异步controller与jQuery实现卷帘式分页
2019/06/18 jQuery
在Vue项目中,防止页面被缩放和放大示例
2019/10/28 Javascript
javascript实现搜索筛选功能实例代码
2020/11/12 Javascript
python结合API实现即时天气信息
2016/01/19 Python
使用Python的Twisted框架编写非阻塞程序的代码示例
2016/05/25 Python
python使用KNN算法手写体识别
2018/02/01 Python
1分钟快速生成用于网页内容提取的xslt
2018/02/23 Python
win8下python3.4安装和环境配置图文教程
2018/07/31 Python
Python使用random模块生成随机数操作实例详解
2019/09/17 Python
python反转列表的三种方式解析
2019/11/08 Python
flask的orm框架SQLAlchemy查询实现解析
2019/12/12 Python
Python类class参数self原理解析
2020/11/19 Python
用Python自动清理系统垃圾的实现
2021/01/18 Python
使用简单的CSS3属性实现炫酷读者墙效果
2014/01/08 HTML / CSS
Nasty Gal英国:美国女性服饰销售网站
2021/03/02 全球购物
为什么要做架构设计
2015/07/08 面试题
Ejb技术面试题
2015/04/29 面试题
本科毕业生的求职信范文
2013/11/20 职场文书
数控机械专业个人的自我评价
2014/01/02 职场文书
小学少先队活动方案
2014/02/18 职场文书
保护环境倡议书100字
2014/05/19 职场文书
社区爱国卫生月活动总结
2014/06/30 职场文书
社会发展项目建议书
2014/08/25 职场文书
幼儿园教师个人总结
2015/02/05 职场文书
优秀团员主要事迹材料
2015/11/05 职场文书
写作之关于描写老人的好段摘抄
2019/11/14 职场文书
python获取字符串中的email
2022/03/31 Python