pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Nginx+uWsgi实现Python的Django框架站点动静分离
Mar 21 Python
python中list列表的高级函数
May 17 Python
python清理子进程机制剖析
Nov 23 Python
python中将字典形式的数据循环插入Excel
Jan 16 Python
mac 安装python网络请求包requests方法
Jun 13 Python
详解python3中tkinter知识点
Jun 21 Python
Selenium chrome配置代理Python版的方法
Nov 29 Python
Python分析彩票记录并预测中奖号码过程详解
Jul 09 Python
python3.7 sys模块的具体使用
Jul 22 Python
python+rsync精确同步指定格式文件
Aug 29 Python
Python爬虫代理池搭建的方法步骤
Sep 28 Python
Golang Web 框架Iris安装部署
Aug 14 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
php实现把url转换迅雷thunder资源下载地址的方法
2014/11/07 PHP
PHPStrom 新建FTP项目以及在线操作教程
2016/10/16 PHP
PHP 7.1中AES加解密方法mcrypt_module_open()的替换方案
2017/10/17 PHP
PHP信号处理机制的操作代码讲解
2019/04/19 PHP
js chrome浏览器判断代码
2010/03/28 Javascript
jQuery Ajax请求状态管理器打包
2012/05/03 Javascript
jQuery 瀑布流 绝对定位布局(二)(延迟AJAX加载图片)
2012/05/23 Javascript
JS打开新窗口的2种方式
2013/04/18 Javascript
JavaScript中this的使用详解
2013/11/08 Javascript
jQuery取得select选择的文本与值的示例
2013/12/09 Javascript
JQuery对ASP.NET MVC数据进行更新删除
2016/07/13 Javascript
JavaScript DOM 对象深入了解
2016/07/20 Javascript
微信小程序实现图片自适应(支持多图)
2017/01/25 Javascript
react-router browserHistory刷新页面404问题解决方法
2017/12/29 Javascript
详解NODEJS的http实现
2018/01/04 NodeJs
vuex的使用及持久化state的方式详解
2018/01/23 Javascript
对angularJs中自定义指令replace的属性详解
2018/10/09 Javascript
js的各种数据类型判断的介绍
2019/01/19 Javascript
你可能从未使用过的11+个JavaScript特性(小结)
2020/01/08 Javascript
python中Matplotlib实现绘制3D图的示例代码
2017/09/04 Python
对Python3.x版本print函数左右对齐详解
2018/12/22 Python
创建Django项目图文实例详解
2019/06/06 Python
详解python tkinter 图片插入问题
2020/09/03 Python
python3中数组逆序输出方法
2020/12/01 Python
css3实现冲击波效果的示例代码
2018/01/11 HTML / CSS
HTML5拖放功能_动力节点Java学院整理
2017/07/13 HTML / CSS
迪梵英国官方网站:Darphin英国
2017/12/06 全球购物
阿联酋最好的手机、电子产品和家用电器网上商店:Eros Digital Home
2020/08/09 全球购物
软件测试题目
2013/02/27 面试题
自荐信封面
2013/12/04 职场文书
大学生自我鉴定
2013/12/16 职场文书
党的群众路线教育实践活动实施方案
2014/10/31 职场文书
2015年禁毒宣传活动总结
2015/03/25 职场文书
Mysql systemctl start mysqld报错的问题解决
2021/06/03 MySQL
MySQL慢查询优化解决问题
2022/03/17 MySQL
纯CSS实现一个简单步骤条的示例代码
2022/07/15 HTML / CSS