pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进程管理工具supervisor使用实例
Sep 17 Python
使用Python解析JSON数据的基本方法
Oct 15 Python
在Windows系统上搭建Nginx+Python+MySQL环境的教程
Dec 25 Python
Python中的pygal安装和绘制直方图代码分享
Dec 08 Python
Python 修改列表中的元素方法
Jun 26 Python
Python实现从SQL型数据库读写dataframe型数据的方法【基于pandas】
Mar 18 Python
pyqt 实现在Widgets中显示图片和文字的方法
Jun 13 Python
python算法与数据结构之冒泡排序实例详解
Jun 22 Python
Python如何读写字节数据
Aug 05 Python
pycharm-professional-2020.1下载与激活的教程
Sep 21 Python
python 邮件检测工具mmpi的使用
Jan 04 Python
浅谈怎么给Python添加类型标注
Jun 08 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
PHP学习 运算符与运算符优先级
2008/06/15 PHP
php&mysql 日期操作小记
2012/02/27 PHP
php性能分析之php-fpm慢执行日志slow log用法浅析
2016/10/17 PHP
PHP编程实现csv文件导入mysql数据库的方法
2017/04/29 PHP
PHP For循环字母A-Z当超过26个字母时输出AA,AB,AC
2020/02/16 PHP
Centos7安装swoole扩展操作示例
2020/03/26 PHP
PhpStorm2020 + phpstudyV8 +XDebug的教程详解
2020/09/17 PHP
一段非常简单的让图片自动切换js代码
2006/11/10 Javascript
javascript import css实例代码
2008/07/18 Javascript
深入理解JavaScript系列(14) 作用域链介绍(Scope Chain)
2012/04/12 Javascript
JS拖拽插件实现步骤
2015/08/03 Javascript
jQuery 获取屏幕高度、宽度的简单实现案例
2016/05/17 Javascript
jQuery实现可拖拽的许愿墙效果【附demo源码下载】
2016/09/14 Javascript
基于JavaScript实现右键菜单和拖拽功能
2016/11/28 Javascript
[Bootstrap-插件使用]Jcrop+fileinput组合实现头像上传功能实例代码
2016/12/20 Javascript
微信小程序 input表单与redio及下拉列表的使用实例
2017/09/20 Javascript
jQuery实现动态控制页面元素的方法分析
2017/12/20 jQuery
bootstrap模态框关闭后清除模态框的数据方法
2018/08/10 Javascript
vue 右键菜单插件 简单、可扩展、样式自定义的右键菜单
2018/11/29 Javascript
[57:12]完美世界DOTA2联赛循环赛 Inki vs Matador BO2第一场 10.31
2020/11/02 DOTA
[03:12]完美世界DOTA2联赛PWL DAY9集锦
2020/11/10 DOTA
python计算圆周率pi的方法
2015/07/11 Python
pandas的object对象转时间对象的方法
2018/04/11 Python
学生信息管理系统Python面向对象版
2019/01/30 Python
pytorch中的transforms模块实例详解
2019/12/31 Python
python实现在线翻译功能
2020/03/03 Python
Python制作一个仿QQ办公版的图形登录界面
2020/09/22 Python
Python关于拓扑排序知识点讲解
2021/01/04 Python
前后端结合实现amazeUI分页效果
2020/08/21 HTML / CSS
小学生自我鉴定
2013/10/12 职场文书
自荐信的五个重要部分
2013/10/29 职场文书
小学生寒假家长评语
2014/04/16 职场文书
超市商业计划书
2014/05/04 职场文书
七夕相亲活动策划方案
2014/08/31 职场文书
综治目标管理责任书
2015/05/11 职场文书
教师年度考核自我评鉴
2015/08/11 职场文书