pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
比较详细Python正则表达式操作指南(re使用)
Sep 06 Python
python 生成目录树及显示文件大小的代码
Jul 23 Python
python通过ftplib登录到ftp服务器的方法
May 08 Python
Python3.5编程实现修改IIS WEB.CONFIG的方法示例
Aug 18 Python
基于python中pygame模块的Linux下安装过程(详解)
Nov 09 Python
TF-IDF算法解析与Python实现方法详解
Nov 16 Python
使用python实现knn算法
Dec 20 Python
Python面向对象编程之继承与多态详解
Jan 16 Python
python 中的list和array的不同之处及转换问题
Mar 13 Python
python实现批量修改服务器密码的方法
Aug 13 Python
Pytorch抽取网络层的Feature Map(Vgg)实例
Aug 20 Python
DRF使用simple JWT身份验证的实现
Jan 14 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
同时提取多条新闻中的文本一例
2006/10/09 PHP
PHP单例模式定义与使用实例详解
2017/02/06 PHP
PHP操作Redis数据库常用方法示例
2018/08/25 PHP
PHP实现文件上传操作和封装
2020/03/04 PHP
JS 日期验证正则附asp日期格式化函数
2009/09/11 Javascript
JS特权方法定义作用以及与公有方法的区别
2013/03/18 Javascript
javascript中的nextSibling使用陷(da)阱(keng)
2014/05/05 Javascript
AspNet中使用JQuery上传插件Uploadify详解
2015/05/20 Javascript
Atitit.js的键盘按键事件捆绑and事件调度
2016/04/01 Javascript
深入理解node exports和module.exports区别
2016/06/01 Javascript
JavaScript函数中关于valueOf和toString的理解
2016/06/14 Javascript
微信小程序 sha1 实现密码加密实例详解
2017/07/06 Javascript
基于jQuery中ajax的相关方法汇总(必看篇)
2017/11/08 jQuery
使用vue如何构建一个自动建站项目
2018/02/05 Javascript
angularjs select 赋值 ng-options配置方法
2018/02/28 Javascript
webpack实现一个行内样式px转vw的loader示例
2018/09/13 Javascript
vue下axios拦截器token刷新机制的实例代码
2020/01/17 Javascript
OpenLayers3实现测量功能
2020/09/25 Javascript
vue实现广告栏上下滚动效果
2020/11/26 Vue.js
pyramid配置session的方法教程
2013/11/27 Python
python网络编程示例(客户端与服务端)
2014/04/24 Python
Python使用函数默认值实现函数静态变量的方法
2014/08/18 Python
Python实现对PPT文件进行截图操作的方法
2015/04/28 Python
python PyTorch参数初始化和Finetune
2018/02/11 Python
HTML5播放实现rtmp流直播
2020/06/16 HTML / CSS
size?爱尔兰官方网站:英国伦敦的球鞋精品店
2019/03/31 全球购物
荷兰音乐会和音乐剧门票订购网站:Topticketshop
2019/08/27 全球购物
澳大利亚手袋、珠宝和在线时尚精品店:The Way
2019/12/21 全球购物
JAVA和C++的区别
2013/10/06 面试题
餐饮业经理竞聘演讲稿
2014/01/14 职场文书
大学生实习感言
2014/01/16 职场文书
暑假家长评语大全
2014/04/17 职场文书
高中英语演讲稿范文
2014/04/24 职场文书
2015年学校团委工作总结
2015/05/26 职场文书
python for循环赋值问题
2021/06/03 Python
HTML5中的DOCUMENT.VISIBILITYSTATE属性详解
2023/05/07 HTML / CSS