Pytorch 实现focal_loss 多类别和二分类示例


Posted in Python onJanuary 14, 2020

我就废话不多说了,直接上代码吧!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
# 支持多分类和二分类
class FocalLoss(nn.Module):
  """
  This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
  'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
    Focal_Loss= -1*alpha*(1-pt)^gamma*log(pt)
  :param num_class:
  :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
  :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
          focus on hard misclassified example
  :param smooth: (float,double) smooth value when cross entropy
  :param balance_index: (int) balance class index, should be specific when alpha is float
  :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
  """
 
  def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True):
    super(FocalLoss, self).__init__()
    self.num_class = num_class
    self.alpha = alpha
    self.gamma = gamma
    self.smooth = smooth
    self.size_average = size_average
 
    if self.alpha is None:
      self.alpha = torch.ones(self.num_class, 1)
    elif isinstance(self.alpha, (list, np.ndarray)):
      assert len(self.alpha) == self.num_class
      self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1)
      self.alpha = self.alpha / self.alpha.sum()
    elif isinstance(self.alpha, float):
      alpha = torch.ones(self.num_class, 1)
      alpha = alpha * (1 - self.alpha)
      alpha[balance_index] = self.alpha
      self.alpha = alpha
    else:
      raise TypeError('Not support alpha type')
 
    if self.smooth is not None:
      if self.smooth < 0 or self.smooth > 1.0:
        raise ValueError('smooth value should be in [0,1]')
 
  def forward(self, input, target):
    logit = F.softmax(input, dim=1)
 
    if logit.dim() > 2:
      # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
      logit = logit.view(logit.size(0), logit.size(1), -1)
      logit = logit.permute(0, 2, 1).contiguous()
      logit = logit.view(-1, logit.size(-1))
    target = target.view(-1, 1)
 
    # N = input.size(0)
    # alpha = torch.ones(N, self.num_class)
    # alpha = alpha * (1 - self.alpha)
    # alpha = alpha.scatter_(1, target.long(), self.alpha)
    epsilon = 1e-10
    alpha = self.alpha
    if alpha.device != input.device:
      alpha = alpha.to(input.device)
 
    idx = target.cpu().long()
    one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
    if one_hot_key.device != logit.device:
      one_hot_key = one_hot_key.to(logit.device)
 
    if self.smooth:
      one_hot_key = torch.clamp(
        one_hot_key, self.smooth, 1.0 - self.smooth)
    pt = (one_hot_key * logit).sum(1) + epsilon
    logpt = pt.log()
 
    gamma = self.gamma
 
    alpha = alpha[idx]
    loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
 
    if self.size_average:
      loss = loss.mean()
    else:
      loss = loss.sum()
    return loss
 
 
 
class BCEFocalLoss(torch.nn.Module):
  """
  二分类的Focalloss alpha 固定
  """
  def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction
 
  def forward(self, _input, target):
    pt = torch.sigmoid(_input)
    alpha = self.alpha
    loss = - alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - \
        (1 - alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt)
    if self.reduction == 'elementwise_mean':
      loss = torch.mean(loss)
    elif self.reduction == 'sum':
      loss = torch.sum(loss)
    return loss

以上这篇Pytorch 实现focal_loss 多类别和二分类示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python标准库之多进程(multiprocessing包)介绍
Nov 25 Python
python检测远程udp端口是否打开的方法
Mar 14 Python
在Python的Flask框架中使用日期和时间的教程
Apr 21 Python
Python环境下搭建属于自己的pip源的教程
May 05 Python
Python实现去除列表中重复元素的方法小结【4种方法】
Apr 27 Python
Python3中在Anaconda环境下安装basemap包
Oct 21 Python
Python中一个for循环循环多个变量的示例
Jul 16 Python
python使用 cx_Oracle 模块进行查询操作示例
Nov 28 Python
Python实现清理微信僵尸粉功能示例【基于itchat模块】
May 29 Python
Python模块zipfile原理及使用方法详解
Aug 04 Python
python爬取音频下载的示例代码
Oct 19 Python
Django框架中模型的用法
Jun 10 Python
Python实现钉钉订阅消息功能
Jan 14 #Python
Python Tensor FLow简单使用方法实例详解
Jan 14 #Python
Python利用全连接神经网络求解MNIST问题详解
Jan 14 #Python
基于pytorch的lstm参数使用详解
Jan 14 #Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
You might like
php中通过smtp发邮件的类,测试通过
2007/01/22 PHP
发布一个迷你php+AJAX聊天程序[聊天室]提供下载
2007/07/21 PHP
centos 5.6 升级php到5.3的方法
2011/05/14 PHP
php自定义函数br2nl实现将html中br换行符转换为文本输入中换行符的方法【与函数nl2br功能相反】
2017/02/17 PHP
PHP ElasticSearch做搜索实例讲解
2020/02/05 PHP
javascript onkeydown,onkeyup,onkeypress,onclick,ondblclick
2009/02/04 Javascript
理解JavaScript变量作用域更轻松
2009/10/25 Javascript
JavaScript 对象链式操作测试代码
2010/04/25 Javascript
网站如何做到完全不需要jQuery也可以满足简单需求
2013/06/27 Javascript
node.js中的fs.unlink方法使用说明
2014/12/15 Javascript
JS+CSS实现感应鼠标渐变显示DIV层的方法
2015/02/20 Javascript
JS烟花背景效果实现方法
2015/03/03 Javascript
酷炫jQuery全屏3D焦点图动画效果
2016/03/22 Javascript
js基本算法:冒泡排序,二分查找的简单实例
2016/10/08 Javascript
微信小程序 教程之引用
2016/10/18 Javascript
微信小程序 loading 详解及实例代码
2016/11/09 Javascript
bootstrap jquery dataTable 异步ajax刷新表格数据的实现方法
2017/02/10 Javascript
jQuery扇形定时器插件pietimer使用方法详解
2017/07/18 jQuery
Vue学习笔记进阶篇之vue-router安装及使用方法
2017/07/19 Javascript
Angular2使用vscode断点调试ts文件的方法
2017/12/13 Javascript
es6 symbol的实现方法示例
2019/04/02 Javascript
JavaScript如何判断input数据类型
2020/02/06 Javascript
30分钟搭建Python的Flask框架并在上面编写第一个应用
2015/03/30 Python
Python中使用platform模块获取系统信息的用法教程
2016/07/08 Python
python3.6的venv模块使用详解
2018/08/01 Python
基于Python和C++实现删除链表的节点
2020/07/06 Python
Python字典fromkeys()方法使用代码实例
2020/07/20 Python
使用CSS3设计地图上的雷达定位提示效果
2016/04/05 HTML / CSS
印尼在线精品店:Berrybenka.com
2016/10/22 全球购物
解释一下钝化(Swap out)
2016/12/26 面试题
税务会计岗位职责
2014/02/18 职场文书
铁路安全反思材料
2014/12/24 职场文书
2016年校园重阳节广播稿
2015/12/18 职场文书
幼儿园语言教学反思
2016/02/23 职场文书
php远程请求CURL案例(爬虫、保存登录状态)
2021/04/01 PHP
英国数字版游戏销量周榜公布 《小缇娜的奇幻之地》登顶
2022/04/03 其他游戏