pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中处理时间之clock()方法的使用
May 22 Python
Python编程实现数学运算求一元二次方程的实根算法示例
Apr 02 Python
python中hashlib模块用法示例
Oct 30 Python
Python+pandas计算数据相关系数的实例
Jul 03 Python
python 文件转成16进制数组的实例
Jul 09 Python
pycharm运行程序时在Python console窗口中运行的方法
Dec 03 Python
python matplotlib画图库学习绘制常用的图
Mar 19 Python
python卸载后再次安装遇到的问题解决
Jul 10 Python
tensorflow自定义激活函数实例
Feb 04 Python
Python多进程multiprocessing、进程池用法实例分析
Mar 24 Python
python 实现学生信息管理系统的示例
Nov 28 Python
教你怎么用Python监控愉客行车程
Apr 29 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
Cakephp 执行主要流程
2010/03/24 PHP
给初学者的30条PHP最佳实践(荒野无灯)
2011/08/02 PHP
Zend studio文件注释模板设置方法
2013/09/29 PHP
php发送html格式文本邮件的方法
2015/06/10 PHP
修复ShopNC使用QQ 互联时提示100010 错误
2015/11/08 PHP
php外部执行命令函数用法小结
2016/10/11 PHP
jquery remove方法应用详解
2012/11/22 Javascript
禁用页面部分JavaScript方法的具体实现
2013/07/31 Javascript
JavaScript中的原型链prototype介绍
2014/12/30 Javascript
javascript实现图像循环明暗变化的方法
2015/02/25 Javascript
js实现文件上传表单域美化特效
2015/11/02 Javascript
微信小程序 网络请求(post请求,get请求)
2017/01/17 Javascript
原生JS实现左右箭头选择日期实例代码
2017/03/14 Javascript
Mongoose经常返回e11000 error的原因分析
2017/03/29 Javascript
Angular实现的日程表功能【可添加及隐藏显示内容】
2017/12/27 Javascript
vue组件中使用props传递数据的实例详解
2018/04/08 Javascript
vue 项目build错误异常的解决方法
2019/04/22 Javascript
解析Python中的异常处理
2015/04/28 Python
利用Python开发实现简单的记事本
2016/11/15 Python
Python 多线程实例详解
2017/03/25 Python
django定期执行任务(实例讲解)
2017/11/03 Python
python之pandas用法大全
2018/03/13 Python
Numpy中转置transpose、T和swapaxes的实例讲解
2018/04/17 Python
Python cookbook(数据结构与算法)将多个映射合并为单个映射的方法
2018/04/19 Python
python日期时间转为字符串或者格式化输出的实例
2018/05/29 Python
Python PyAutoGUI模块控制鼠标和键盘实现自动化任务详解
2018/09/04 Python
Python redis操作实例分析【连接、管道、发布和订阅等】
2019/05/16 Python
简单了解python调用其他脚本方法实例
2020/03/26 Python
Python批量处理csv并保存过程解析
2020/05/16 Python
幼儿园庆六一游园活动方案
2014/01/29 职场文书
英文版辞职信
2015/02/28 职场文书
合作与交流自我评价
2015/03/09 职场文书
拉贝日记观后感
2015/06/05 职场文书
土木工程生产实习心得体会
2016/01/22 职场文书
2016国庆促销广告语
2016/01/28 职场文书
CSS 伪元素::marker详解
2021/06/26 HTML / CSS