Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈python中scipy.misc.logsumexp函数的运用场景
Jun 23 Python
详解Python开发中如何使用Hook技巧
Nov 01 Python
Python读取图片为16进制表示简单代码
Jan 19 Python
Flask之flask-script模块使用
Jul 26 Python
Python图像处理之简单画板实现方法示例
Aug 30 Python
Python 3.x基于Xml数据的Http请求方法
Dec 28 Python
django配置连接数据库及原生sql语句的使用方法
Mar 03 Python
使用matlab 判断两个矩阵是否相等的实例
May 11 Python
python实现猜单词游戏
May 22 Python
提高python代码运行效率的一些建议
Sep 29 Python
matplotlib自定义鼠标光标坐标格式的实现
Jan 08 Python
Python自动化测试基础必备知识点总结
Feb 07 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
PHP中基本符号及使用方法
2010/03/23 PHP
PHP序列号生成函数和字符串替换函数代码
2012/06/07 PHP
php生成验证码函数
2015/10/20 PHP
PHP 7安装调试工具Xdebug扩展的方法教程
2017/06/17 PHP
PHP var关键字相关原理及使用实例解析
2020/07/11 PHP
CLASS_CONFUSION JS混淆 全源码
2007/12/12 Javascript
jQuery select控制插件
2009/08/17 Javascript
extjs grid设置某列背景颜色和字体颜色的实现方法
2010/09/06 Javascript
jQuery实现类似淘宝购物车全选状态示例
2013/06/26 Javascript
jQuery获得内容和属性方法及示例
2013/12/02 Javascript
分享纯手写漂亮的表单验证
2015/11/19 Javascript
jQuery技巧之让任何组件都支持类似DOM的事件管理
2016/04/05 Javascript
JavaScript输入分钟、秒倒计时技巧总结(附代码)
2017/08/17 Javascript
详解Node.js利用node-git-server快速搭建git服务器
2017/09/27 Javascript
详解如何在vue项目中使用eslint+prettier格式化代码
2018/11/10 Javascript
layui问题之模拟table表格中的选中按钮选中事件的方法
2019/09/20 Javascript
jQuery实现鼠标滑动切换图片
2020/05/27 jQuery
[00:37]DOTA2上海特级锦标赛 OG战队宣传片
2016/03/03 DOTA
Python threading多线程编程实例
2014/09/18 Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
2018/05/24 Python
python reverse反转部分数组的实例
2018/12/13 Python
Python  Django 母版和继承解析
2019/08/09 Python
10分钟教你用python动画演示深度优先算法搜寻逃出迷宫的路径
2019/08/12 Python
纯css3实现效果超级炫的checkbox复选框和radio单选框
2014/09/01 HTML / CSS
南威尔士家居商店:Leekes
2016/10/25 全球购物
Chantelle仙黛尔内衣美国官网:法国第一品牌内衣
2018/07/26 全球购物
WebSphere面试题:在WebSphere里面如何部署一个应用
2015/08/02 面试题
KTV的创业计划书范文
2014/02/02 职场文书
五分钟演讲稿
2014/04/30 职场文书
10的分与合教学反思
2014/04/30 职场文书
员工安全承诺书
2014/05/22 职场文书
户籍证明格式
2014/09/15 职场文书
大学生党员个人总结
2015/02/13 职场文书
劳动仲裁调解书
2015/05/20 职场文书
开学典礼观后感
2015/06/15 职场文书
《天使的翅膀》读后感3篇
2019/12/20 职场文书