pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现冒泡,插入,选择排序简单实例
Aug 18 Python
Python实现的多线程http压力测试代码
Feb 08 Python
python决策树之C4.5算法详解
Dec 20 Python
Python zip()函数用法实例分析
Mar 17 Python
pandas 两列时间相减换算为秒的方法
Apr 20 Python
Python 3.3实现计算两个日期间隔秒数/天数的方法示例
Jan 07 Python
Python生成指定数量的优惠码实操内容
Jun 18 Python
keras 如何保存最佳的训练模型
May 25 Python
Python爬虫之爬取淘女郎照片示例详解
Jul 28 Python
10张动图学会python循环与递归问题
Feb 06 Python
python实现简单反弹球游戏
Apr 12 Python
解决Tkinter中button按钮未按却主动执行command函数的问题
May 23 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
PHP新手上路(七)
2006/10/09 PHP
解析php二分法查找数组是否包含某一元素
2013/05/23 PHP
js中格式化日期时间型数据函数代码
2010/11/08 Javascript
javascript整除实现代码
2010/11/23 Javascript
仿新浪微博返回顶部的jquery实现代码
2012/10/01 Javascript
JavaScript实现网页图片等比例缩放实现代码及调用方式
2013/02/25 Javascript
javascript jq 弹出层实例
2013/08/25 Javascript
js中arguments的用法(实例讲解)
2013/11/30 Javascript
解决JQeury显示内容没有边距内容紧挨着浏览器边线
2013/12/20 Javascript
基于JavaScript实现拖动滑块效果
2017/02/16 Javascript
react-native fetch的具体使用方法
2017/11/01 Javascript
vue中v-for加载本地静态图片方法
2018/03/03 Javascript
JS实现字符串中去除指定子字符串方法分析
2018/05/17 Javascript
详解Vue CLI3 多页应用实践和源码设计
2018/08/30 Javascript
vue在index.html中引入静态文件不生效问题及解决方法
2019/04/29 Javascript
基于Vue的商品主图放大镜方案详解
2019/09/19 Javascript
浅析vue中的provide / inject 有什么用处
2019/11/10 Javascript
[01:15:12]DOTA2上海特级锦标赛主赛事日 - 1 败者组第一轮#4Newbee VS CDEC
2016/03/03 DOTA
[46:16]2018DOTA2亚洲邀请赛3月30日 小组赛B组 iG VS VP
2018/03/31 DOTA
[01:05:41]EG vs Optic Supermajor 败者组 BO3 第二场 6.6
2018/06/07 DOTA
python进阶教程之函数对象(函数也是对象)
2014/08/30 Python
简述Python2与Python3的不同点
2018/01/21 Python
详解django自定义中间件处理
2018/11/21 Python
python实现把两个二维array叠加成三维array示例
2019/11/29 Python
python实现俄罗斯方块游戏(改进版)
2020/03/13 Python
Tensorflow中的图(tf.Graph)和会话(tf.Session)的实现
2020/04/22 Python
CSS3实现精美横向滚动菜单按钮
2017/04/14 HTML / CSS
html5中使用hotcss.js实现手机端自适配的方法
2020/04/23 HTML / CSS
孤独星球出版物:Lonely Planet Publications
2018/03/17 全球购物
什么是抽象
2015/12/13 面试题
大学生求职自荐信
2013/12/12 职场文书
违纪检讨书
2015/01/27 职场文书
大学生入党群众意见书
2015/06/02 职场文书
Python3.8官网文档之类的基础语法阅读
2021/09/04 Python
python模拟浏览器 使用selenium进入好友QQ空间并留言
2022/04/12 Python
Win11软件图标固定到任务栏
2022/04/19 数码科技