PyTorch: Softmax多分类实战操作


Posted in Python onJuly 07, 2020

多分类一种比较常用的做法是在最后一层加softmax归一化,值最大的维度所对应的位置则作为该样本对应的类。本文采用PyTorch框架,选用经典图像数据集mnist学习一波多分类。

MNIST数据集

MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

PyTorch: Softmax多分类实战操作

其中:

train-images-idx3-ubyte.gz (训练数据集图片)

train-labels-idx1-ubyte.gz (训练数据集标记类别)

t10k-images-idx3-ubyte.gz: (测试数据集)

t10k-labels-idx1-ubyte.gz(测试数据集标记类别)

PyTorch: Softmax多分类实战操作

MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示,如下图784维向量作为第一层输入特征。

PyTorch: Softmax多分类实战操作

Softmax分类

softmax函数的本质就是将一个K 维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间,并且压缩后的K个值相加等于1(变成了概率分布)。在选用Softmax做多分类时,可以根据值的大小来进行多分类的任务,如取权重最大的一维。softmax介绍和公式网上很多,这里不介绍了。下面使用Pytorch定义一个多层网络(4个隐藏层,最后一层softmax概率归一化),输出层为10正好对应10类。

PyTorch: Softmax多分类实战操作

PyTorch实战

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = nn.Linear(784, 520)
    self.l2 = nn.Linear(520, 320)
    self.l3 = nn.Linear(320, 240)
    self.l4 = nn.Linear(240, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    # Flatten the data (n, 1, 28, 28) --> (n, 784)
    x = x.view(-1, 784)
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    x = F.relu(self.l3(x))
    x = F.relu(self.l4(x))
    return F.log_softmax(self.l5(x), dim=1)
    #return self.l5(x)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):

  # 每次输入barch_idx个数据
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)

    optimizer.zero_grad()
    output = model(data)
    # loss
    loss = F.nll_loss(output, target)
    loss.backward()
    # update
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))
def test():
  test_loss = 0
  correct = 0
  # 测试集
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target).data[0]
    # get the index of the max
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

for epoch in range(1,6):
  train(epoch)
  test()

输出结果:
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.292192
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.289466
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.294221
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.169656
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.561276

Test set: Average loss: 0.0163, Accuracy: 6698/10000 (67%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.993218
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.859608
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.499748
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.422055
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.413933

Test set: Average loss: 0.0065, Accuracy: 8797/10000 (88%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.465154
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.321842
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.187147
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.469552
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.270332

Test set: Average loss: 0.0045, Accuracy: 9137/10000 (91%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.197497
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.234830
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.260302
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.219375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.292754

Test set: Average loss: 0.0037, Accuracy: 9277/10000 (93%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 0.183354
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.207930
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.138435
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.120214
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.266199

Test set: Average loss: 0.0026, Accuracy: 9506/10000 (95%)
Process finished with exit code 0

随着训练迭代次数的增加,测试集的精确度还是有很大提高的。并且当迭代次数为5时,使用这种简单的网络可以达到95%的精确度。

以上这篇PyTorch: Softmax多分类实战操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python 获取进程pid号的方法
Mar 10 Python
Python常见工厂函数用法示例
Mar 21 Python
Python爬虫之pandas基本安装与使用方法示例
Aug 08 Python
python读取文本中的坐标方法
Oct 14 Python
Python实现简单的列表冒泡排序和反转列表操作示例
Jul 10 Python
python实现网站微信登录的示例代码
Sep 18 Python
关于Flask项目无法使用公网IP访问的解决方式
Nov 19 Python
Python基于requests库爬取网站信息
Mar 02 Python
如何实现在jupyter notebook中播放视频(不停地展示图片)
Apr 23 Python
Python 捕获代码中所有异常的方法
Aug 03 Python
python 利用matplotlib在3D空间中绘制平面的案例
Feb 06 Python
Python页面加载的等待方式总结
Feb 28 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 #Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 #Python
Python基于httpx模块实现发送请求
Jul 07 #Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 #Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 #Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
You might like
PHP通用检测函数集合
2006/11/25 PHP
php图片验证码代码
2008/03/27 PHP
php中获取远程客户端的真实ip地址的方法
2011/08/03 PHP
基于php权限分配的实现代码
2013/04/28 PHP
解析php 版获取重定向后的地址(代码)
2013/06/26 PHP
php使用正则过滤js脚本代码实例
2014/05/10 PHP
JavaScript使用技巧精萃[代码非常实用]
2008/11/21 Javascript
Javascript load Page,load css,load js实现代码
2010/03/31 Javascript
非jQuery实现照片散落桌子上,单击放大的LightBox效果
2014/11/28 Javascript
基于jQuery实现左右图片轮播(原理通用)
2015/12/24 Javascript
学习JavaScript鼠标响应事件
2015/12/25 Javascript
基于jQuery实现返回顶部实例代码
2016/01/01 Javascript
JavaScript中对JSON对象的基本操作示例
2016/05/21 Javascript
AngularJs directive详解及示例代码
2016/09/01 Javascript
利用n工具轻松管理Node.js的版本
2017/04/21 Javascript
js中的闭包学习心得
2018/02/06 Javascript
转换layUI的数据表格中的日期格式方法
2019/09/19 Javascript
js的Object.assign用法示例分析
2020/03/05 Javascript
详解ES6 中的Object.assign()的用法实例代码
2021/01/11 Javascript
python抓取豆瓣图片并自动保存示例学习
2014/01/10 Python
python基于itchat实现微信群消息同步机器人
2017/02/27 Python
解决Python网页爬虫之中文乱码问题
2018/05/11 Python
python实现对列表中的元素进行倒序打印
2019/11/23 Python
python读取图像矩阵文件并转换为向量实例
2020/06/18 Python
Python中使用aiohttp模拟服务器出现错误问题及解决方法
2020/10/31 Python
改变生活的男士内衣:SAXX Underwear
2019/08/28 全球购物
交通法规咨询中心工作职责
2013/11/27 职场文书
销售找工作求职信
2013/12/20 职场文书
理工学院学生自我鉴定
2014/02/23 职场文书
党员岗位承诺口号大全
2014/03/28 职场文书
大学生入党推荐书范文
2014/05/17 职场文书
4s店销售经理岗位职责
2014/07/19 职场文书
2019奶茶店创业计划书范本!
2019/07/15 职场文书
Java后台生成图片的完整步骤
2021/08/04 Java/Android
剧场版《转生恶役只好拔除破灭旗标》公开最新视觉图 2023年上映
2022/04/02 日漫
Ruby处理CSV数据方法详解
2022/04/18 Ruby