Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈Python的Django框架中的缓存控制
Jul 24 Python
浅谈Python中列表生成式和生成器的区别
Aug 03 Python
浅析AST抽象语法树及Python代码实现
Jun 06 Python
python中json格式数据输出的简单实现方法
Oct 31 Python
python获取当前运行函数名称的方法实例代码
Apr 06 Python
python读取excel表格生成erlang数据
Aug 26 Python
一篇文章快速了解Python的GIL
Jan 12 Python
python读取csv文件并把文件放入一个list中的实例讲解
Apr 27 Python
Python中正则表达式的用法总结
Feb 22 Python
python 日期排序的实例代码
Jul 11 Python
用Python写一个自动木马程序
Sep 17 Python
详解python中*号的用法
Oct 21 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
jquery.lazyload  实现图片延迟加载jquery插件
2010/02/06 Javascript
超级酷和最实用的jQuery实例收集(20个)
2010/04/21 Javascript
基于jquery的从一个页面跳转到另一个页面的指定位置的实现代码(带平滑移动的效果)
2011/05/24 Javascript
Prototype源码浅析 String部分(二)
2012/01/16 Javascript
js判断运行jsp页面的浏览器类型以及版本示例
2013/10/30 Javascript
使用js画图之画切线
2015/01/12 Javascript
JS实现点击上移下移LI行数据的方法
2015/08/05 Javascript
javascript新闻跑马灯实例代码
2020/07/29 Javascript
jQuery插件实现表格隔行变色及鼠标滑过高亮显示效果代码
2016/02/25 Javascript
JS实现队列与堆栈的方法
2016/04/21 Javascript
JS中的==运算: [''] == false —&gt;true
2016/07/24 Javascript
用vue封装插件并发布到npm的方法步骤
2017/10/18 Javascript
vue组件发布到npm简单步骤
2017/11/30 Javascript
vue2.0 computed 计算list循环后累加值的实例
2018/03/07 Javascript
Node.js模拟发起http请求从异步转同步的5种用法
2018/09/26 Javascript
Vue父组件如何获取子组件中的变量
2019/07/24 Javascript
javascript设计模式 ? 装饰模式原理与应用实例分析
2020/04/14 Javascript
使用Python3制作TCP端口扫描器
2017/04/17 Python
matplotlib savefig 保存图片大小的实例
2018/05/24 Python
python爬取盘搜的有效链接实现代码
2019/07/20 Python
python如何实现单链表的反转
2020/02/10 Python
From CSV to SQLite3 by python 导入csv到sqlite实例
2020/02/14 Python
Python中常用的高阶函数实例详解
2020/02/21 Python
python破解同事的压缩包密码
2020/10/14 Python
Python创建文件夹与文件的快捷方法
2020/12/08 Python
宝拉珍选澳大利亚官方购物网站:Paula’s Choice澳大利亚
2016/09/13 全球购物
JSF如何进行表格处理及取值
2012/08/06 面试题
中学教师岗位职责
2013/11/26 职场文书
热爱祖国演讲稿
2014/05/04 职场文书
学校工作推荐信范文
2014/07/11 职场文书
四风查摆问题自查报告
2014/10/10 职场文书
暑假安全保证书
2015/02/28 职场文书
借款民事起诉状范文
2015/05/19 职场文书
运动会开幕式致辞
2015/07/29 职场文书
校友会致辞
2015/07/30 职场文书
Python中使用ipython的详细教程
2021/06/22 Python