pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
自动化Nginx服务器的反向代理的配置方法
Jun 28 Python
浅谈Django学习migrate和makemigrations的差别
Jan 18 Python
python中的set实现不重复的排序原理
Jan 24 Python
用tensorflow构建线性回归模型的示例代码
Mar 05 Python
Python操作redis实例小结【String、Hash、List、Set等】
May 16 Python
Python学习笔记之字符串和字符串方法实例详解
Aug 22 Python
Python使用百度api做人脸对比的方法
Aug 28 Python
python 装饰器功能与用法案例详解
Mar 06 Python
xadmin使用formfield_for_dbfield函数过滤下拉表单实例
Apr 07 Python
详解python中GPU版本的opencv常用方法介绍
Jul 24 Python
Python爬虫之Selenium设置元素等待的方法
Dec 04 Python
python实现会员管理系统
Mar 18 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
php INI配置文件的解析实现分析
2011/01/04 PHP
php截取html字符串及自动补全html标签的方法
2015/01/15 PHP
CI操作cookie的方法分析(基于helper类库)
2016/03/28 PHP
Yii控制器中操作视图js的方法
2016/07/04 PHP
PHP中->和=>的含义及使用示例解析
2020/08/06 PHP
js操作textarea方法集合封装(兼容IE,firefox)
2011/02/22 Javascript
javascript里使用php代码实例
2014/12/13 Javascript
javascript中的正则表达式使用指南
2015/03/01 Javascript
jQuery模仿京东/天猫商品左侧分类导航菜单效果
2016/06/29 Javascript
AngularJS  双向数据绑定详解简单实例
2016/10/20 Javascript
Node.js开启Https的实践详解
2016/10/25 Javascript
AngularJS实现自定义指令及指令配置项的方法
2017/11/20 Javascript
在Vue项目中,防止页面被缩放和放大示例
2019/10/28 Javascript
Vue.js实现大屏数字滚动翻转效果
2019/11/29 Javascript
使用js实现单链解决前端队列问题的方法
2020/02/03 Javascript
JavaScript数组排序的六种常见算法总结
2020/08/18 Javascript
[37:03]完美世界DOTA2联赛PWL S3 INK ICE vs GXR 第二场 12.16
2020/12/18 DOTA
python3 爬取图片的实例代码
2018/11/06 Python
Python3.5实现的三级菜单功能示例
2019/03/25 Python
在Qt5和PyQt5中设置支持高分辨率屏幕自适应的方法
2019/06/18 Python
python程序 创建多线程过程详解
2019/09/23 Python
如何更改 pandas dataframe 中两列的位置
2019/12/27 Python
python 已知一个字符,在一个list中找出近似值或相似值实现模糊匹配
2020/02/29 Python
Python小白垃圾回收机制入门
2020/06/09 Python
python BeautifulSoup库的安装与使用
2020/12/17 Python
美国顶级奢侈茶:Mighty Leaf Tea(美泰茶)
2016/11/26 全球购物
白色公司:The White Company
2017/10/11 全球购物
创先争优活动方案
2014/02/12 职场文书
公司中秋节活动方案
2014/02/12 职场文书
《爱如茉莉》教后反思
2014/04/12 职场文书
陈胜吴广起义口号
2014/06/20 职场文书
2014入党积极分子批评与自我批评思想报告
2014/10/06 职场文书
2015年妇联工作总结范文
2015/04/22 职场文书
2022年四月新番
2022/03/15 日漫
Python实现Hash算法
2022/03/18 Python
vue项目打包后路由错误的解决方法
2022/04/13 Vue.js