Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现bitmap数据结构详解
Feb 17 Python
Python日志模块logging简介
Apr 13 Python
python在windows下创建隐藏窗口子进程的方法
Jun 04 Python
Django中处理出错页面的方法
Jul 15 Python
利用pandas合并多个excel的方法示例
Oct 10 Python
Python程序暂停的正常处理方法
Nov 07 Python
使用pyqt 实现重复打开多个相同界面
Dec 13 Python
Python图像处理库PIL的ImageGrab模块介绍详解
Feb 26 Python
Python-opencv实现红绿两色识别操作
Jun 04 Python
Python结合Window计划任务监测邮件的示例代码
Aug 05 Python
tensorflow+k-means聚类简单实现猫狗图像分类的方法
Apr 28 Python
Python if else条件语句形式详解
Mar 24 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
PHP iconv 解决utf-8和gb2312编码转换问题
2010/04/12 PHP
php的一些小问题
2010/07/03 PHP
PHP排序之二维数组的按照字母排序实现代码
2011/08/13 PHP
PHP隐形一句话后门,和ThinkPHP框架加密码程序(base64_decode)
2011/11/02 PHP
PHP 下载文件时自动添加bom头的方法实例
2014/01/10 PHP
Thinkphp中的volist标签用法简介
2014/06/18 PHP
Yii2实现同时搜索多个字段的方法
2016/08/10 PHP
js获取电脑分辨率的思路及操作
2013/11/22 Javascript
jquery的live使用注意事项
2014/02/18 Javascript
JavaScript实现控制打开文件另存为对话框的方法
2015/04/17 Javascript
js实现仿MSN带关闭功能的右下角弹窗代码
2015/09/04 Javascript
Angular2表单自定义验证器的实现
2016/10/19 Javascript
JavaScript函数表达式详解及实例
2017/05/05 Javascript
vue底部加载更多的实例代码
2018/06/29 Javascript
Javascript删除数组里的某个元素
2019/02/28 Javascript
angularjs请求数据的方法示例
2019/08/06 Javascript
使用Layui搭建后台管理界面的操作方法
2019/09/20 Javascript
python Matplotlib画图之调整字体大小的示例
2017/11/20 Python
tensorflow TFRecords文件的生成和读取的方法
2018/02/06 Python
Python中pillow知识点学习
2018/04/30 Python
Python for循环生成列表的实例
2018/06/15 Python
python的pip安装以及使用教程
2018/09/18 Python
对Python生成汉字字库文字,以及转换为文字图片的实例详解
2019/01/29 Python
Python简易版图书管理系统
2019/08/12 Python
使用python的turtle绘画滑稽脸实例
2019/11/21 Python
keras 自定义loss层+接受输入实例
2020/06/28 Python
利用Python pandas对Excel进行合并的方法示例
2020/11/04 Python
opencv实现图像几何变换
2021/03/24 Python
《少年王勃》教学反思
2014/04/27 职场文书
欢迎词范文
2015/01/27 职场文书
实习单位鉴定意见
2015/06/04 职场文书
清明节主题班会
2015/08/14 职场文书
文艺有韵味的诗句(生命类、亲情类...)
2019/07/11 职场文书
高中语文教材(文学文化常识大全一)
2019/08/13 职场文书
Python实现归一化算法详情
2022/03/18 Python
Redis监控工具RedisInsight安装与使用
2022/03/21 Redis