Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 中的列表解析和生成表达式
Mar 10 Python
python使用Queue在多个子进程间交换数据的方法
Apr 18 Python
python实现实时监控文件的方法
Aug 26 Python
python中使用正则表达式的连接符示例代码
Oct 10 Python
python爬虫使用cookie登录详解
Dec 27 Python
详解Python连接MySQL数据库的多种方式
Apr 16 Python
Python中print和return的作用及区别解析
May 05 Python
Python如何基于selenium实现自动登录博客园
Dec 16 Python
Pycharm激活码激活两种快速方式(附最新激活码和插件)
Mar 12 Python
python 3.8.3 安装配置图文教程
May 21 Python
python代码能做成软件吗
Jul 24 Python
python自动化办公操作PPT的实现
Feb 05 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
php 表单数据的获取代码
2009/03/10 PHP
php下封装较好的数字分页方法
2010/11/23 PHP
PHP 时间日期操作实战
2011/08/26 PHP
php实现批量下载百度云盘文件例子分享
2014/04/10 PHP
浅析PHP数据导出知识点
2018/02/17 PHP
php实现获取近几日、月时间示例
2019/07/06 PHP
用jquery ajax获取网站Alexa排名的代码
2009/12/12 Javascript
JavaScript之Getters和Setters 平台支持等详细介绍
2012/12/07 Javascript
为何JS操作的href都是javascript:void(0);呢
2015/11/12 Javascript
jquery自定义表格样式
2015/11/23 Javascript
JQuery实现网页右侧随动广告特效
2016/01/17 Javascript
jQuery实现的鼠标经过时变宽的效果(附demo源码)
2016/04/28 Javascript
Javascript中for循环语句的几种写法总结对比
2017/01/23 Javascript
使用Math.max,Math.min获取数组中的最值实例
2017/04/25 Javascript
jQuery实现简单的回到顶部totop功能示例
2017/10/16 jQuery
关于自定义Egg.js的请求级别日志详解
2018/12/12 Javascript
VUE简单的定时器实时刷新的实现方法
2019/01/20 Javascript
Typescript的三种运行方式(小结)
2019/09/18 Javascript
小程序怎样让wx.navigateBack更好用的方法实现
2019/11/01 Javascript
element表格翻页第2页从1开始编号(后端从0开始分页)
2019/12/10 Javascript
Element Carousel 走马灯的具体实现
2020/07/26 Javascript
VScode编写第一个Python程序HelloWorld步骤
2018/04/06 Python
pandas 对每一列数据进行标准化的方法
2018/06/09 Python
浅析python3字符串格式化format()函数的简单用法
2018/12/07 Python
Python异常处理例题整理
2019/07/07 Python
python可视化实现KNN算法
2019/10/16 Python
Python和Sublime整合过程图示
2019/12/25 Python
python连接mysql数据库并读取数据的实现
2020/09/25 Python
python中四舍五入的正确打开方式
2021/01/18 Python
StubHub德国:购买和出售门票
2017/09/06 全球购物
几道PHP面试题
2013/04/14 面试题
Java里面StringBuilder和StringBuffer有什么区别
2016/06/06 面试题
门卫岗位职责说明书
2014/08/18 职场文书
环卫工人节活动总结
2014/08/29 职场文书
小学英语课教学反思
2016/02/15 职场文书
详解Flutter和Dart取消Future的三种方法
2022/04/07 Java/Android