pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python多线程扫描端口示例
Jan 16 Python
从零学python系列之从文件读取和保存数据
May 23 Python
Tornado Web服务器多进程启动的2个方法
Aug 04 Python
python转换字符串为摩尔斯电码的方法
Jul 06 Python
Python导出DBF文件到Excel的方法
Jul 25 Python
Python3中使用urllib的方法详解(header,代理,超时,认证,异常处理)
Sep 21 Python
Python wxPython库使用wx.ListBox创建列表框示例
Sep 03 Python
python 将list转成字符串,中间用符号分隔的方法
Oct 23 Python
Python简易版图书管理系统
Aug 12 Python
通过 Django Pagination 实现简单分页功能
Nov 11 Python
Python如何输出百分比
Jul 31 Python
Python移位密码、仿射变换解密实例代码
Jun 27 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
一棵php的类树(支持无限分类)
2006/10/09 PHP
php 使用array函数实现分页
2015/02/13 PHP
PHP引用的调用方法分析
2016/04/25 PHP
javascript实现上传图片前的预览(TX的面试题)
2007/08/20 Javascript
对象特征检测法判断浏览器对javascript对象的支持
2009/07/25 Javascript
jquery 事件冒泡的介绍以及如何阻止事件冒泡
2012/12/25 Javascript
from 表单提交返回值用post或者是get方法实现
2013/08/21 Javascript
Js判断CSS文件加载完毕的具体实现
2014/01/17 Javascript
扒一扒JavaScript 预解释
2015/01/28 Javascript
Jquery ajax加载等待执行结束再继续执行下面代码操作
2015/11/24 Javascript
AngularJS 获取ng-repeat动态生成的ng-model值实例详解
2016/11/29 Javascript
实例分析nodejs模块xml2js解析xml过程中遇到的坑
2017/03/18 NodeJs
vue.js评论发布信息可插入QQ表情功能
2017/08/08 Javascript
Angularjs实现下拉框联动的示例代码
2017/08/22 Javascript
详解Vue.js项目API、Router配置拆分实践
2018/03/16 Javascript
React props和state属性的具体使用方法
2018/04/12 Javascript
一篇文章,教你学会Vue CLI 插件开发
2019/04/17 Javascript
NUXT SSR初级入门笔记(小结)
2019/12/16 Javascript
通过源码分析Python中的切片赋值
2017/05/08 Python
Python基于正则表达式实现文件内容替换的方法
2017/08/30 Python
Python cookbook(字符串与文本)在字符串的开头或结尾处进行文本匹配操作
2018/04/20 Python
对于Python深浅拷贝的理解
2019/07/29 Python
Python字符串格式化输出代码实例
2019/11/22 Python
python判断无向图环是否存在的示例
2019/11/22 Python
运行Python编写的程序方法实例
2020/10/21 Python
html5 viewport使用方法示例详解
2013/12/02 HTML / CSS
个人求职信范文分享
2014/01/06 职场文书
鸿星尔克广告词
2014/03/21 职场文书
篮球赛口号
2014/06/18 职场文书
敬老院标语
2014/06/27 职场文书
建筑工程造价专业自荐信
2014/07/08 职场文书
四风剖析查摆对照检查材料思想汇报
2014/09/24 职场文书
股权转让协议范本
2014/12/07 职场文书
酒店财务经理岗位职责
2015/04/08 职场文书
幼儿教师继续教育培训心得体会
2016/01/19 职场文书
python not运算符的实例用法
2021/06/30 Python