PyTorch: Softmax多分类实战操作


Posted in Python onJuly 07, 2020

多分类一种比较常用的做法是在最后一层加softmax归一化,值最大的维度所对应的位置则作为该样本对应的类。本文采用PyTorch框架,选用经典图像数据集mnist学习一波多分类。

MNIST数据集

MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

PyTorch: Softmax多分类实战操作

其中:

train-images-idx3-ubyte.gz (训练数据集图片)

train-labels-idx1-ubyte.gz (训练数据集标记类别)

t10k-images-idx3-ubyte.gz: (测试数据集)

t10k-labels-idx1-ubyte.gz(测试数据集标记类别)

PyTorch: Softmax多分类实战操作

MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示,如下图784维向量作为第一层输入特征。

PyTorch: Softmax多分类实战操作

Softmax分类

softmax函数的本质就是将一个K 维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间,并且压缩后的K个值相加等于1(变成了概率分布)。在选用Softmax做多分类时,可以根据值的大小来进行多分类的任务,如取权重最大的一维。softmax介绍和公式网上很多,这里不介绍了。下面使用Pytorch定义一个多层网络(4个隐藏层,最后一层softmax概率归一化),输出层为10正好对应10类。

PyTorch: Softmax多分类实战操作

PyTorch实战

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = nn.Linear(784, 520)
    self.l2 = nn.Linear(520, 320)
    self.l3 = nn.Linear(320, 240)
    self.l4 = nn.Linear(240, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    # Flatten the data (n, 1, 28, 28) --> (n, 784)
    x = x.view(-1, 784)
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    x = F.relu(self.l3(x))
    x = F.relu(self.l4(x))
    return F.log_softmax(self.l5(x), dim=1)
    #return self.l5(x)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):

  # 每次输入barch_idx个数据
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)

    optimizer.zero_grad()
    output = model(data)
    # loss
    loss = F.nll_loss(output, target)
    loss.backward()
    # update
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))
def test():
  test_loss = 0
  correct = 0
  # 测试集
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target).data[0]
    # get the index of the max
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

for epoch in range(1,6):
  train(epoch)
  test()

输出结果:
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.292192
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.289466
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.294221
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.169656
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.561276

Test set: Average loss: 0.0163, Accuracy: 6698/10000 (67%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.993218
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.859608
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.499748
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.422055
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.413933

Test set: Average loss: 0.0065, Accuracy: 8797/10000 (88%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.465154
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.321842
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.187147
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.469552
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.270332

Test set: Average loss: 0.0045, Accuracy: 9137/10000 (91%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.197497
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.234830
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.260302
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.219375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.292754

Test set: Average loss: 0.0037, Accuracy: 9277/10000 (93%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 0.183354
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.207930
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.138435
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.120214
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.266199

Test set: Average loss: 0.0026, Accuracy: 9506/10000 (95%)
Process finished with exit code 0

随着训练迭代次数的增加,测试集的精确度还是有很大提高的。并且当迭代次数为5时,使用这种简单的网络可以达到95%的精确度。

以上这篇PyTorch: Softmax多分类实战操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
easy_install python包安装管理工具介绍
Feb 10 Python
Python中在脚本中引用其他文件函数的实现方法
Jun 23 Python
python解决汉字编码问题:Unicode Decode Error
Jan 19 Python
Python实现可设置持续运行时间、线程数及时间间隔的多线程异步post请求功能
Jan 11 Python
Python with语句上下文管理器两种实现方法分析
Feb 09 Python
python实现寻找最长回文子序列的方法
Jun 02 Python
python排序函数sort()与sorted()的区别
Sep 18 Python
python项目对接钉钉SDK的实现
Jul 15 Python
python实现一个函数版的名片管理系统过程解析
Aug 27 Python
python3实现微型的web服务器
Sep 03 Python
Django使用Celery加redis执行异步任务的实例内容
Feb 20 Python
Keras设置以及获取权重的实现
Jun 19 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 #Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 #Python
Python基于httpx模块实现发送请求
Jul 07 #Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 #Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 #Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
You might like
php处理restful请求的路由类分享
2014/02/27 PHP
PHP static局部静态变量和全局静态变量总结
2014/03/02 PHP
关于Yii中模型场景的一些简单介绍
2019/09/22 PHP
Vagrant(WSL)+PHPStorm+Xdebu 断点调试环境搭建
2019/12/13 PHP
JavaScript数据结构与算法之栈详解
2015/03/12 Javascript
基于Css3和JQuery实现打字机效果
2015/08/11 Javascript
Jquery检验手机号是否符合规则并根据手机号检测结果将提交按钮设为不同状态
2015/11/26 Javascript
jQuery实现批量判断表单中文本框非空的方法(2种方法)
2015/12/09 Javascript
bootstrap table实现点击翻页功能 可记录上下页选中的行
2017/09/28 Javascript
vue中h5端打开app(判断是安卓还是苹果)
2021/02/26 Vue.js
[01:31](回顾)杀出重围,决战TI之巅
2014/07/01 DOTA
Python和perl实现批量对目录下电子书文件重命名的代码分享
2014/11/21 Python
Python中input与raw_input 之间的比较
2017/08/20 Python
利用Tkinter(python3.6)实现一个简单计算器
2017/12/21 Python
Python编程求解二叉树中和为某一值的路径代码示例
2018/01/04 Python
解决python升级引起的pip执行错误的问题
2018/06/12 Python
python 使用pandas计算累积求和的方法
2019/02/08 Python
python实现扫描ip地址的小程序
2019/04/16 Python
详解如何从TensorFlow的mnist数据集导出手写体数字图片
2019/08/05 Python
Python DataFrame一列拆成多列以及一行拆成多行
2019/08/06 Python
Python Django框架url反向解析实现动态生成对应的url链接示例
2019/10/18 Python
基于django2.2连oracle11g解决版本冲突的问题
2020/07/02 Python
深入了解Python 方法之类方法 & 静态方法
2020/08/17 Python
HTML5 语义化结构化规范化
2008/10/17 HTML / CSS
Shopee马来西亚:随拍即卖,最佳行动电商拍卖平台
2017/06/05 全球购物
天逸系统(武汉)有限公司Java笔试题
2015/12/29 面试题
销售辞职报告范文
2014/01/12 职场文书
高中军训感言400字
2014/02/24 职场文书
公司法人授权委托书范本
2014/09/12 职场文书
秋冬农业生产标语
2014/10/09 职场文书
杜甫草堂导游词
2015/02/03 职场文书
教师岗位职责
2015/02/03 职场文书
初中教师个人总结
2015/02/10 职场文书
上诉状格式
2015/05/23 职场文书
2016年教师新年寄语
2015/08/18 职场文书
django项目、vue项目部署云服务器的详细过程
2022/07/23 Servers