pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python开发编码规范
Sep 08 Python
50行Python代码实现人脸检测功能
Jan 23 Python
Python处理CSV与List的转换方法
Apr 19 Python
Python logging模块用法示例
Aug 28 Python
Python numpy中矩阵的基本用法汇总
Feb 12 Python
python gensim使用word2vec词向量处理中文语料的方法
Jul 05 Python
通过PHP与Python代码对比的语法差异详解
Jul 10 Python
python中enumerate() 与zip()函数的使用比较实例分析
Sep 03 Python
python提取xml里面的链接源码详解
Oct 15 Python
简单了解python装饰器原理及使用方法
Dec 18 Python
Python打包模块wheel的使用方法与将python包发布到PyPI的方法详解
Feb 12 Python
Python基于pyjnius库实现访问java类
Jul 31 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
WINDOWS服务器安装多套PHP的另类解决方案
2006/10/09 PHP
php读取富文本的时p标签会出现红线是怎么回事
2014/05/13 PHP
网站防止被刷票的一些思路与方法
2015/01/08 PHP
php简单判断文本编码的方法
2015/07/30 PHP
javascript中的对象和数组的应用技巧
2007/01/07 Javascript
jfreechart插件将数据展示成饼状图、柱状图和折线图
2015/04/13 Javascript
JavaScript实现简单获取当前网页网址的方法
2015/11/09 Javascript
jQuery实现本地预览上传图片功能
2016/01/08 Javascript
详解Angular开发中的登陆与身份验证
2016/07/27 Javascript
JS获取中文拼音首字母并通过拼音首字母快速查找页面内对应中文内容的方法【附demo源码】
2016/08/19 Javascript
关于RequireJS的简单介绍即使用方法
2016/10/20 Javascript
bootstrap datetimepicker控件位置异常的解决方法
2017/11/23 Javascript
vue2.0 实现导航守卫的具体用法(路由守卫)
2018/05/17 Javascript
CKEditor4配置与开发详细中文说明文档
2018/10/08 Javascript
Vue 报错TypeError: this.$set is not a function 的解决方法
2018/12/17 Javascript
微信小程序 点击切换样式scroll-view实现代码实例
2019/10/11 Javascript
javascript实现贪吃蛇经典游戏
2020/04/10 Javascript
vue如何在项目中调用腾讯云的滑动验证码
2020/07/15 Javascript
[03:47]2015国际邀请赛第三日现场精彩回顾
2015/08/08 DOTA
python学习之第三方包安装方法(两种方法)
2015/07/30 Python
python 读文件,然后转化为矩阵的实例
2018/04/23 Python
Python 内置函数进制转换的用法(十进制转二进制、八进制、十六进制)
2018/04/30 Python
Django框架实现的简单分页功能示例
2018/12/04 Python
Pytorch中的variable, tensor与numpy相互转化的方法
2019/10/10 Python
什么是Python中的顺序表
2020/06/02 Python
selenium如何定位span元素的实现
2021/01/13 Python
纯CSS3实现鼠标滑过按钮动画第二节
2020/07/16 HTML / CSS
NBA欧洲商店(法国):NBA Europe Store FR
2016/10/19 全球购物
美国最大的网上冲印店:Shutterfly
2017/01/01 全球购物
售房协议书范本2014
2014/10/23 职场文书
自我评价优缺点范文
2015/03/11 职场文书
2015年酒店工作总结范文
2015/04/07 职场文书
创业计划书之美甲店
2019/09/20 职场文书
对PyTorch中inplace字段的全面理解
2021/05/22 Python
深入理解以DEBUG方式线程的底层运行原理
2021/06/21 Java/Android
MySQL下载安装配置详细教程 附下载资源
2022/09/23 MySQL