Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pyramid添加Middleware的方法实例
Nov 27 Python
Python下的Mysql模块MySQLdb安装详解
Apr 09 Python
python数据结构之二叉树的建立实例
Apr 29 Python
python网络编程之数据传输UDP实例分析
May 20 Python
python中matplotlib实现最小二乘法拟合的过程详解
Jul 11 Python
Python插件virtualenv搭建虚拟环境
Nov 20 Python
python通过Windows下远程控制Linux系统
Jun 20 Python
Python设计模式之代理模式实例详解
Jan 19 Python
python中sympy库求常微分方程的用法
Apr 28 Python
Python flask框架实现查询数据库并显示数据
Jun 04 Python
python两个list[]相加的实现方法
Sep 23 Python
从np.random.normal()到正态分布的拟合操作
Jun 02 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
针对初学PHP者的疑难问答(1)
2006/10/09 PHP
微信公众平台实现获取用户OpenID的方法
2015/04/15 PHP
PHP处理数组和XML之间的互相转换
2016/06/02 PHP
thinkPHP多表查询及分页功能实现方法示例
2017/07/03 PHP
PHP日期和时间函数的使用示例详解
2020/08/06 PHP
JScript中使用ADODB.Stream判断文件编码的代码
2008/06/09 Javascript
JavaScript中的作用域链和闭包
2012/06/30 Javascript
基于jquery的跟随屏幕滚动代码
2012/07/24 Javascript
jquery实现marquee效果(文字或者图片的水平垂直滚动)
2013/01/07 Javascript
用C/C++来实现 Node.js 的模块(一)
2014/09/24 Javascript
详解JS函数重载
2014/12/04 Javascript
JS实现slide文字框缩放伸展效果代码
2015/11/05 Javascript
React中上传图片到七牛的示例代码
2017/10/10 Javascript
vue2单元测试环境搭建
2018/05/24 Javascript
解决vue中使用Axios调用接口时出现的ie数据处理问题
2018/08/13 Javascript
vuejs实现折叠面板展开收缩动画效果
2018/09/06 Javascript
CSS3 动画卡顿性能优化的完美解决方案
2018/09/20 Javascript
详解Angular模板引用变量及其作用域
2018/11/23 Javascript
JQuery样式与属性设置方法分析
2019/12/07 jQuery
win7 x64系统中安装Scrapy的方法
2018/11/18 Python
pyftplib中文乱码问题解决方案
2020/01/11 Python
Python基于yield遍历多个可迭代对象
2020/03/12 Python
Python3爬虫发送请求的知识点实例
2020/07/30 Python
Python QT组件库qtwidgets的使用
2020/11/02 Python
HTML5拖放功能_动力节点Java学院整理
2017/07/13 HTML / CSS
让IE下支持Html5的placeholder属性的插件
2014/09/02 HTML / CSS
超级英雄、电影和电视、乐队和音乐T恤:Loud Clothing
2019/09/01 全球购物
某/etc/fstab文件中的某行如下: /dev/had5 /mnt/dosdata msdos defaults,usrquota 1 2 请解释其含义
2013/04/11 面试题
学校三八妇女节活动情况总结
2014/03/09 职场文书
年终晚会活动方案
2014/08/21 职场文书
关于感恩的演讲稿800字
2014/08/26 职场文书
银行竞聘报告范文
2014/11/06 职场文书
2015年爱国卫生月活动总结
2015/03/26 职场文书
关于法制教育的宣传语
2015/07/13 职场文书
500字作文之关于爸爸
2019/11/14 职场文书
十大最强电系宝可梦,阿尔宙斯电系之一,第七被称为雷神
2022/03/18 日漫