Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现查询苹果手机维修进度
Mar 16 Python
python基本语法练习实例
Sep 19 Python
python机器学习实战之K均值聚类
Dec 20 Python
python2.6.6如何升级到python2.7.14
Apr 08 Python
Python自动化之数据驱动让你的脚本简洁10倍【推荐】
Jun 04 Python
如何通过Python实现标签云算法
Jul 02 Python
python框架flask表单实现详解
Nov 04 Python
Python高级编程之继承问题详解(super与mro)
Nov 19 Python
如何基于python操作excel并获取内容
Dec 24 Python
Python使用循环神经网络解决文本分类问题的方法详解
Jan 16 Python
Python拼接字符串的7种方式详解
Mar 19 Python
浅谈Python中re.match()和re.search()的使用及区别
Apr 14 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
PHP 创建标签云函数代码
2010/05/26 PHP
session在php5.3中的变化 session_is_registered() is deprecated in
2013/11/12 PHP
CodeIgniter删除和设置Cookie的方法
2015/04/07 PHP
JQuery textlimit 显示用户输入的字符数 限制用户输入的字符数
2009/05/14 Javascript
js实现拉伸拖动iframe的具体代码
2013/08/03 Javascript
JavaScript的null和undefined区别示例介绍
2014/09/15 Javascript
javascript几个易错点记录
2014/11/26 Javascript
PHPMyAdmin导入时提示文件大小超出PHP限制的解决方法
2015/03/30 Javascript
深入浅析同源策略和跨域访问
2015/11/26 Javascript
详解AngularJS中$http缓存以及处理多个$http请求的方法
2016/02/06 Javascript
JS实现滑动门效果的方法详解
2016/12/19 Javascript
JS正则获取HTML元素的方法
2017/03/31 Javascript
老生常谈js-react组件生命周期
2017/05/02 Javascript
Vue.js添加组件操作示例
2018/06/13 Javascript
layui 数据表格复选框实现单选功能的例子
2019/09/19 Javascript
在Python中使用next()方法操作文件的教程
2015/05/24 Python
详解Python的Django框架中inclusion_tag的使用
2015/07/21 Python
用于业余项目的8个优秀Python库
2018/09/21 Python
python GUI编程(Tkinter) 创建子窗口及在窗口上用图片绘图实例
2020/03/04 Python
详解Python中的Lock和Rlock
2021/01/26 Python
Canvas 帧动画吃苹果小游戏
2020/08/05 HTML / CSS
国际领先的学术出版商:Springer
2017/01/11 全球购物
软件测试英文面试题
2012/10/14 面试题
团组织关系介绍信
2014/01/12 职场文书
先进工作者获奖感言
2014/02/08 职场文书
学生会主席演讲稿
2014/04/25 职场文书
会计学专业自荐信
2014/06/25 职场文书
党员对照检查材料思想汇报
2014/09/16 职场文书
整改落实自查报告
2014/11/05 职场文书
2015新年寄语(一句话)
2014/12/08 职场文书
酒店采购员岗位职责
2015/04/03 职场文书
行政撤诉申请书
2015/05/18 职场文书
2016年社会主义核心价值观心得体会
2016/01/21 职场文书
委托书范本格式
2019/04/18 职场文书
联想win10摄像头打不开怎么办?win10笔记本摄像头打不开解决办法
2022/04/08 数码科技
MySQL运行报错:“Expression #1 of SELECT list is not in GROUP BY clause and contains nonaggre”解决方法
2022/06/14 MySQL