PyTorch: Softmax多分类实战操作


Posted in Python onJuly 07, 2020

多分类一种比较常用的做法是在最后一层加softmax归一化,值最大的维度所对应的位置则作为该样本对应的类。本文采用PyTorch框架,选用经典图像数据集mnist学习一波多分类。

MNIST数据集

MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

PyTorch: Softmax多分类实战操作

其中:

train-images-idx3-ubyte.gz (训练数据集图片)

train-labels-idx1-ubyte.gz (训练数据集标记类别)

t10k-images-idx3-ubyte.gz: (测试数据集)

t10k-labels-idx1-ubyte.gz(测试数据集标记类别)

PyTorch: Softmax多分类实战操作

MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示,如下图784维向量作为第一层输入特征。

PyTorch: Softmax多分类实战操作

Softmax分类

softmax函数的本质就是将一个K 维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间,并且压缩后的K个值相加等于1(变成了概率分布)。在选用Softmax做多分类时,可以根据值的大小来进行多分类的任务,如取权重最大的一维。softmax介绍和公式网上很多,这里不介绍了。下面使用Pytorch定义一个多层网络(4个隐藏层,最后一层softmax概率归一化),输出层为10正好对应10类。

PyTorch: Softmax多分类实战操作

PyTorch实战

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = nn.Linear(784, 520)
    self.l2 = nn.Linear(520, 320)
    self.l3 = nn.Linear(320, 240)
    self.l4 = nn.Linear(240, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    # Flatten the data (n, 1, 28, 28) --> (n, 784)
    x = x.view(-1, 784)
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    x = F.relu(self.l3(x))
    x = F.relu(self.l4(x))
    return F.log_softmax(self.l5(x), dim=1)
    #return self.l5(x)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):

  # 每次输入barch_idx个数据
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)

    optimizer.zero_grad()
    output = model(data)
    # loss
    loss = F.nll_loss(output, target)
    loss.backward()
    # update
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))
def test():
  test_loss = 0
  correct = 0
  # 测试集
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target).data[0]
    # get the index of the max
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

for epoch in range(1,6):
  train(epoch)
  test()

输出结果:
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.292192
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.289466
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.294221
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.169656
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.561276

Test set: Average loss: 0.0163, Accuracy: 6698/10000 (67%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.993218
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.859608
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.499748
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.422055
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.413933

Test set: Average loss: 0.0065, Accuracy: 8797/10000 (88%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.465154
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.321842
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.187147
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.469552
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.270332

Test set: Average loss: 0.0045, Accuracy: 9137/10000 (91%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.197497
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.234830
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.260302
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.219375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.292754

Test set: Average loss: 0.0037, Accuracy: 9277/10000 (93%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 0.183354
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.207930
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.138435
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.120214
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.266199

Test set: Average loss: 0.0026, Accuracy: 9506/10000 (95%)
Process finished with exit code 0

随着训练迭代次数的增加,测试集的精确度还是有很大提高的。并且当迭代次数为5时,使用这种简单的网络可以达到95%的精确度。

以上这篇PyTorch: Softmax多分类实战操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈python中的实例方法、类方法和静态方法
Feb 17 Python
pandas数据分组和聚合操作方法
Apr 11 Python
PyQt5每天必学之工具提示功能
Apr 19 Python
matplotlib实现热成像图colorbar和极坐标图的方法
Dec 13 Python
python可视化爬虫界面之天气查询
Jul 03 Python
python 利用jinja2模板生成html代码实例
Oct 10 Python
python实现将json多行数据传入到mysql中使用
Dec 31 Python
Python pandas 列转行操作详解(类似hive中explode方法)
May 18 Python
使用keras实现非线性回归(两种加激活函数的方式)
Jul 05 Python
Django mysqlclient安装和使用详解
Sep 17 Python
python 实现控制鼠标键盘
Nov 27 Python
python用opencv 图像傅里叶变换
Jan 04 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 #Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 #Python
Python基于httpx模块实现发送请求
Jul 07 #Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 #Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 #Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
You might like
我的论坛源代码(六)
2006/10/09 PHP
CodeIgniter基本配置详细介绍
2013/11/12 PHP
Laravel 5框架学习之用户认证
2015/04/09 PHP
PHP SOCKET编程详解
2015/05/22 PHP
使用PHP similar text计算两个字符串相似度
2015/11/06 PHP
PHPStrom 新建FTP项目以及在线操作教程
2016/10/16 PHP
Laravel Validator 实现两个或多个字段联合索引唯一
2019/05/08 PHP
Jquery实现列表(隔行换色,全选,鼠标滑过当前行)效果实例
2013/06/09 Javascript
常用的几段javascript代码分享
2014/03/25 Javascript
apply和call方法定义及apply和call方法的区别
2015/11/15 Javascript
Three.js快速入门教程
2016/09/09 Javascript
jQuery实现手机版页面翻页效果的简单实例
2016/10/05 Javascript
JS两种类型的表单提交方法实例分析
2016/11/28 Javascript
JS实现的计数排序与基数排序算法示例
2017/12/04 Javascript
bootstrap 路径导航 分页 进度条的实例代码
2018/08/06 Javascript
javascript中UMD规范的代码推演
2018/08/29 Javascript
Vue2.x Todo之自定义指令实现自动聚焦的方法
2019/01/08 Javascript
vue使用Font Awesome的方法步骤
2019/02/26 Javascript
详解vue项目中调用百度地图API使用方法
2019/04/25 Javascript
pageGroup.js实现分页功能
2019/07/27 Javascript
JavaScript Tab菜单实现过程解析
2020/05/13 Javascript
vue-quill-editor插入图片路径太长问题解决方法
2021/01/08 Vue.js
react-native 实现购物车滑动删除效果的示例代码
2021/01/15 Javascript
python使用三角迭代计算圆周率PI的方法
2015/03/20 Python
python 列表降维的实例讲解
2018/06/28 Python
python制作mysql数据迁移脚本
2019/01/01 Python
python脚本第一行如何写
2020/08/30 Python
特罗佩亚包官方网站:Tropea
2017/01/03 全球购物
Big Green Smile法国:领先的英国有机和天然产品在线商店
2021/01/02 全球购物
枚举和一组预处理的#define有什么不同
2016/09/21 面试题
如何撰写一封出色的求职信
2014/04/27 职场文书
三万活动总结
2014/04/28 职场文书
我的长生果教学反思
2014/04/28 职场文书
考试作弊检讨书范文
2015/01/27 职场文书
小马王观后感
2015/06/11 职场文书
2021年pycharm的最新安装教程及基本使用图文详解
2021/04/03 Python