pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
理解Python中的With语句
Feb 02 Python
简单介绍Python下自己编写web框架的一些要点
Apr 29 Python
在Python的Django框架中加载模版的方法
Jul 16 Python
安装ElasticSearch搜索工具并配置Python驱动的方法
Dec 22 Python
对Python3 goto 语句的使用方法详解
Feb 16 Python
详解python statistics模块及函数用法
Oct 27 Python
django实现用户注册实例讲解
Oct 30 Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 Python
python+opencv实现移动侦测(帧差法)
Mar 20 Python
python读取yaml文件后修改写入本地实例
Apr 27 Python
python boto和boto3操作bucket的示例
Oct 30 Python
python套接字socket通信
Apr 01 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
codeigniter框架The URI you submitted has disallowed characters错误解决方法
2014/05/06 PHP
PHP MVC框架skymvc支持多文件上传
2016/05/26 PHP
Symfony2针对输入时间进行查询的方法分析
2017/06/28 PHP
Laravel 5.5官方推荐的Nginx配置学习教程
2017/10/06 PHP
Yii框架的布局文件实例分析
2019/09/04 PHP
jQuery地图map悬停显示省市代码分享
2015/08/20 Javascript
javascript随机抽取0-100之间不重复的10个数
2016/02/25 Javascript
js实现固定宽高滑动轮播图效果
2017/01/13 Javascript
JavaScript代码执行的先后顺序问题
2017/10/29 Javascript
Thinkjs3新手入门之添加一个新的页面
2017/12/06 Javascript
axios中cookie跨域及相关配置示例详解
2017/12/20 Javascript
vue+elementUI动态生成面包屑导航教程
2019/11/04 Javascript
Vue解析带html标签的字符串为dom的实例
2019/11/13 Javascript
vue axios请求成功却进入catch的原因分析
2020/09/08 Javascript
实现python版本的按任意键继续/退出
2016/09/26 Python
Python八大常见排序算法定义、实现及时间消耗效率分析
2018/04/27 Python
PyTorch CNN实战之MNIST手写数字识别示例
2018/05/29 Python
利用python GDAL库读写geotiff格式的遥感影像方法
2018/11/29 Python
python实现多张图片拼接成大图
2019/01/15 Python
如何关掉pycharm中的python console(图解)
2019/10/31 Python
python 解决tqdm模块不能单行显示的问题
2020/02/19 Python
Django DRF路由与扩展功能的实现
2020/06/03 Python
python uuid生成唯一id或str的最简单案例
2021/01/13 Python
荷兰度假屋租赁网站:Aan Zee
2020/02/28 全球购物
BannerBuzz加拿大:在线定制横幅印刷、广告和标志
2020/03/10 全球购物
奖学金自我鉴定范文
2013/10/03 职场文书
工程师岗位职责规定
2014/02/26 职场文书
班主任与学生安全责任书
2014/07/25 职场文书
学习礼仪心得体会
2014/09/01 职场文书
2014年教研室工作总结
2014/12/06 职场文书
天气温馨提示语
2015/07/14 职场文书
新闻稿格式范文
2015/07/18 职场文书
公司考勤管理制度
2015/08/04 职场文书
python爬虫之利用selenium模块自动登录CSDN
2021/04/22 Python
Pycharm 如何设置HTML文件自动补全代码或标签
2021/05/21 Python
python实现学生信息管理系统(面向对象)
2022/06/05 Python