在PyTorch中使用标签平滑正则化的问题


Posted in Python onApril 03, 2020

什么是标签平滑?在PyTorch中如何去使用它?

在训练深度学习模型的过程中,过拟合和概率校准(probability calibration)是两个常见的问题。一方面,正则化技术可以解决过拟合问题,其中较为常见的方法有将权重调小,迭代提前停止以及丢弃一些权重等。另一方面,Platt标度法和isotonic regression法能够对模型进行校准。但是有没有一种方法可以同时解决过拟合和模型过度自信呢?

标签平滑也许可以。它是一种去改变目标变量的正则化技术,能使模型的预测结果不再仅为一个确定值。标签平滑之所以被看作是一种正则化技术,是因为它可以防止输入到softmax函数的最大logits值变得特别大,从而使得分类模型变得更加准确。

在这篇文章中,我们定义了标签平滑化,在测试过程中我们将它应用到交叉熵损失函数中。

标签平滑?

假设这里有一个多分类问题,在这个问题中,目标变量通常是一个one-hot向量,即当处于正确分类时结果为1,否则结果是0。

标签平滑改变了目标向量的最小值,使它为ε。因此,当模型进行分类时,其结果不再仅是1或0,而是我们所要求的1-ε和ε,从而带标签平滑的交叉熵损失函数为如下公式。

在PyTorch中使用标签平滑正则化的问题

在这个公式中,ce(x)表示x的标准交叉熵损失函数,例如:-log(p(x)),ε是一个非常小的正数,i表示对应的正确分类,N为所有分类的数量。

直观上看,标记平滑限制了正确类的logit值,并使得它更接近于其他类的logit值。从而在一定程度上,它被当作为一种正则化技术和一种对抗模型过度自信的方法。

PyTorch中的使用

在PyTorch中,带标签平滑的交叉熵损失函数实现起来非常简单。首先,让我们使用一个辅助函数来计算两个值之间的线性组合。

deflinear_combination(x, y, epsilon):return epsilon*x + (1-epsilon)*y

下一步,我们使用PyTorch中一个全新的损失函数:nn.Module.

import torch.nn.functional as F
defreduce_loss(loss, reduction='mean'):return loss.mean() if reduction=='mean'else loss.sum() if reduction=='sum'else loss
classLabelSmoothingCrossEntropy(nn.Module):def__init__(self, epsilon:float=0.1, reduction='mean'):
    super().__init__()
    self.epsilon = epsilon
    self.reduction = reduction

  defforward(self, preds, target):
    n = preds.size()[-1]
    log_preds = F.log_softmax(preds, dim=-1)
    loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
    nll = F.nll_loss(log_preds, target, reduction=self.reduction)
    return linear_combination(loss/n, nll, self.epsilon)

我们现在可以在代码中删除这个类。对于这个例子,我们使用标准的fast.ai pets example.

from fastai.vision import *
from fastai.metrics import error_rate
# prepare the data
path = untar_data(URLs.PETS)
path_img = path/'images'
fnames = get_image_files(path_img)
bs = 64
np.random.seed(2)
pat = r'/([^/]+)_\d+.jpg$'
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs) \
           .normalize(imagenet_stats)
# train the model
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.loss_func = LabelSmoothingCrossEntropy()
learn.fit_one_cycle(4)

最后将数据转换成模型可以使用的格式,选择ResNet架构并以带标签平滑的交叉熵损失函数作为优化目标。经过四轮循环后,其结果如下

在PyTorch中使用标签平滑正则化的问题

我们所得结果的错误率仅为7.5%,这对于10行左右的代码来说是完全可以接受的,并且在模型中大多数参数还都选择的是默认设置。

因此,在模型中还有许多参数可以进行调整,从而使得模型的表现性能更好,例如:可以使用不同的优化器、超参数、模型架构等。

结论

在这篇文章中,我们了解了什么是标签平滑以及什么时候去使用它,并且我们还知道了如何在PyTorch中实现它。之后,我们训练了一个先进的计算机视觉模型,仅使用十行代码就识别出了不同品种的猫和狗。

模型正则化和模型校准是两个重要的概念。若想成为一个深度学习的资深玩家,就应该好好地去理解这些能够对抗过拟合和模型过度自信的工具。

作者简介: Dimitris Poulopoulos,是BigDataStack的一名机器学习研究员,同时也是希腊Piraeus大学的博士。曾为欧盟委员会、欧盟统计局、国际货币基金组织、欧洲央行等客户设计过与AI相关的软件。

总结

到此这篇关于如何在PyTorch中使用标签平滑正则化的文章就介绍到这了,更多相关PyTorch正则化内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python数据结构之二叉树的遍历实例
Apr 29 Python
Python 描述符(Descriptor)入门
Nov 20 Python
Python 专题五 列表基础知识(二维list排序、获取下标和处理txt文本实例)
Mar 20 Python
python制作小说爬虫实录
Aug 14 Python
python3.7.0的安装步骤
Aug 27 Python
python 从文件夹抽取图片另存的方法
Dec 04 Python
Python语言检测模块langid和langdetect的使用实例
Feb 19 Python
pyqt5 tablewidget 利用线程动态刷新数据的方法
Jun 17 Python
用Python实现BP神经网络(附代码)
Jul 10 Python
python flask几分钟实现web服务的例子
Jul 26 Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 Python
Window10上Tensorflow的安装(CPU和GPU版本)
Dec 15 Python
pip install 使用国内镜像的方法示例
Apr 03 #Python
pycharm解决关闭flask后依旧可以访问服务的问题
Apr 03 #Python
Windows+Anaconda3+PyTorch+PyCharm的安装教程图文详解
Apr 03 #Python
基于python图像处理API的使用示例
Apr 03 #Python
解决json中ensure_ascii=False的问题
Apr 03 #Python
基于Python的OCR实现示例
Apr 03 #Python
Python %r和%s区别代码实例解析
Apr 03 #Python
You might like
php实现微信原生支付(扫码支付)功能
2018/05/30 PHP
HTML 自动伸缩的表格Table js实现
2009/04/01 Javascript
js 禁止选择功能实现代码(兼容IE/Firefox)
2010/04/23 Javascript
基于Jquery的动态添加控件并取值的实现代码
2010/09/24 Javascript
从盛大通行证上摘下来的身份证验证js代码
2011/01/11 Javascript
情人节之礼 js项链效果
2012/02/13 Javascript
JQuery插件Style定制化方法的分析与比较
2012/05/03 Javascript
JS限制文本框只能输入数字和字母方法
2015/02/28 Javascript
JS数组排序技巧汇总(冒泡、sort、快速、希尔等排序)
2015/11/24 Javascript
基于AngularJS实现的工资计算器实例
2017/06/16 Javascript
详解如何让InstantClick兼容MathJax、百度统计等
2017/09/12 Javascript
基于Vue.js与WordPress Rest API构建单页应用详解
2019/09/16 Javascript
小程序实现上下移动切换位置
2019/09/23 Javascript
微信小程序select下拉框实现源码
2019/11/08 Javascript
python通过urllib2爬网页上种子下载示例
2014/02/24 Python
Python中为什么要用self探讨
2015/04/14 Python
numpy.random.seed()的使用实例解析
2018/02/03 Python
对numpy中shape的深入理解
2018/06/15 Python
详解Django+Uwsgi+Nginx的生产环境部署
2018/06/25 Python
python manage.py runserver流程解析
2019/11/08 Python
Python Django中的STATIC_URL 设置和使用方式
2020/03/27 Python
Python引入多个模块及包的概念过程解析
2020/09/21 Python
Python实现PS滤镜中的USM锐化效果
2020/12/04 Python
CSS Houdini实现动态波浪纹效果
2019/07/30 HTML / CSS
HTML5自定义属性前缀data-及dataset的使用方法(html5 新特性)
2017/08/24 HTML / CSS
伦敦一家非常流行的时尚精品店:Oxygen Boutique
2017/01/15 全球购物
阿玛尼美妆俄罗斯官网:Giorgio Armani Beauty RU
2020/07/19 全球购物
同步和异步有何异同,在什么情况下分别使用他们?举例说明
2014/02/27 面试题
如何设置Java的运行环境
2013/04/05 面试题
SQL Server面试题
2016/10/17 面试题
春季运动会广播稿大全
2014/02/19 职场文书
大学团日活动新闻稿
2014/09/10 职场文书
党员自我剖析材料范文
2014/10/06 职场文书
公证处委托书
2015/01/28 职场文书
《清澈的湖水》教学反思
2016/02/17 职场文书
Python网络编程之ZeroMQ知识总结
2021/04/25 Python