pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫代理IP池实现方法
Jan 05 Python
python中类和实例如何绑定属性与方法示例详解
Aug 18 Python
简单谈谈Python的pycurl模块
Apr 07 Python
Python之文字转图片方法
May 10 Python
python实现雨滴下落到地面效果
Jun 21 Python
python 实现将txt文件多行合并为一行并将中间的空格去掉方法
Dec 20 Python
Python3 log10()函数简单用法
Feb 19 Python
Django自定义用户登录认证示例代码
Jun 30 Python
关于django 1.10 CSRF验证失败的解决方法
Aug 31 Python
Python常用编译器原理及特点解析
Mar 23 Python
python实现一个猜拳游戏
Apr 05 Python
python实现数字炸弹游戏
Jul 17 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
PHP关于IE下的iframe跨域导致session丢失问题解决方法
2013/10/10 PHP
thinkphp实现附件上传功能
2017/05/26 PHP
详解php框架Yaf路由重写
2017/06/20 PHP
PHP自定义函数实现assign()数组分配到模板及extract()变量分配到模板功能示例
2018/05/23 PHP
Laravel 5+ .env环境配置文件详解
2020/04/06 PHP
javascript入门·图片对象(无刷新变换图片)\滚动图像
2007/10/01 Javascript
JavaScript表单常用验证集合
2008/01/16 Javascript
javascript 手动给表增加数据的小例子
2013/07/10 Javascript
jquery indexOf使用方法
2013/08/19 Javascript
jquery实现加载等待效果示例
2013/09/25 Javascript
JS连连看源码完美注释版(推荐)
2013/12/09 Javascript
使用CSS3的scale实现网页整体缩放
2014/03/18 Javascript
JS实现双击编辑可修改状态的方法
2015/08/14 Javascript
jQuery实现的仿百度分页足迹效果代码
2015/10/30 Javascript
jquery实现的点击翻书效果代码
2015/11/04 Javascript
Javascript对象字面量的理解
2016/06/22 Javascript
javascript中数组(Array)对象和字符串(String)对象的常用方法总结
2016/12/15 Javascript
jQuery菜单实例(全选,反选,取消)
2017/08/28 jQuery
微信小程序实现文件、图片上传功能
2020/08/18 Javascript
一些可能会用到的Node.js面试题
2019/06/15 Javascript
JavaScript实现单图片上传并预览功能
2019/09/30 Javascript
微信小程序实现侧边分类栏
2019/10/21 Javascript
使用Python开发windows GUI程序入门实例
2014/10/23 Python
python批量导入数据进Elasticsearch的实例
2018/05/30 Python
Python logging模块用法示例
2018/08/28 Python
详解Python做一个名片管理系统
2019/03/14 Python
解决django服务器重启端口被占用的问题
2019/07/26 Python
css3中的calc函数浅析
2018/07/10 HTML / CSS
最好的商品表达自己:Cafepress
2019/09/04 全球购物
我能否用void** 指针作为参数, 使函数按引用接受一般指针
2013/02/16 面试题
2014年教师节活动总结
2014/08/29 职场文书
创先争优活动心得体会
2014/09/04 职场文书
员工手册编写范本
2015/05/14 职场文书
2019自荐信范文集锦!
2019/07/03 职场文书
导游词之吉林花园山
2019/10/17 职场文书
MySQL8.0的WITH查询详情
2021/08/30 MySQL