pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python爬虫入门教程之点点美女图片爬虫代码分享
Sep 02 Python
python爬虫headers设置后无效的解决方法
Oct 21 Python
Flask-Mail用法实例分析
Jul 21 Python
pandas 条件搜索返回列表的方法
Oct 30 Python
python 动态生成变量名以及动态获取变量的变量名方法
Jan 20 Python
python实现剪切功能
Jan 23 Python
Python中函数的返回值示例浅析
Aug 28 Python
pycharm部署、配置anaconda环境的教程
Mar 24 Python
PyQt5实现登录页面
May 30 Python
Python while true实现爬虫定时任务
Jun 08 Python
使用keras时input_shape的维度表示问题说明
Jun 29 Python
python 模块重载的五种方法
Apr 24 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
如何删除多级目录
2006/10/09 PHP
在PHP中PDO解决中文乱码问题的一些补充
2010/09/06 PHP
php跨域cookie共享使用方法
2014/02/20 PHP
Javascript & DHTML 实例编程(教程)(三)初级实例篇1—上传文件控件实例
2007/06/02 Javascript
基于JQuery制作的产品广告效果
2010/12/08 Javascript
onkeypress字符按键兼容所有浏览器使用介绍
2013/04/24 Javascript
7个有用的jQuery代码片段分享
2015/05/19 Javascript
jQuery实现首页图片淡入淡出效果的方法
2015/06/10 Javascript
JQuery菜单效果的两个实例讲解(3)
2015/09/17 Javascript
浅谈JS原型对象和原型链
2016/03/02 Javascript
js创建jsonArray传输至后台及后台全面解析
2016/04/11 Javascript
jQuery+php实时获取及响应文本框输入内容的方法
2016/05/24 Javascript
jquery 中toggle的2种用法详解(推荐)
2016/09/02 Javascript
jQuery Ajax请求后台数据并在前台接收
2016/12/10 Javascript
jq checkbox 的全选并ajax传参的实例
2017/04/01 Javascript
JavaScript数据结构之二叉树的遍历算法示例
2017/04/13 Javascript
妙用Angularjs实现表格按指定列排序
2017/06/23 Javascript
利用Javascript开发一个二维周视图日历
2017/12/14 Javascript
React实现全局组件的Toast轻提示效果
2018/09/21 Javascript
代码实例ajax实现点击加载更多数据图片
2018/10/12 Javascript
puppeteer实现html截图的示例代码
2019/01/10 Javascript
详解Python中dict与set的使用
2015/08/10 Python
Python3实现的Mysql数据库操作封装类
2018/06/06 Python
python opencv实现运动检测
2018/07/10 Python
Python使用sort和class实现的多级排序功能示例
2018/08/15 Python
简单介绍python封装的基本知识
2019/08/10 Python
通过字符串导入 Python 模块的方法详解
2019/10/27 Python
跑步爱好者一站式服务网站:Jack Rabbit
2016/09/01 全球购物
Lentiamo丹麦:购买便宜的隐形眼镜
2021/01/13 全球购物
小班重阳节活动方案
2014/02/08 职场文书
竞选班干部演讲稿
2014/04/24 职场文书
机关干部个人对照检查材料思想汇报
2014/09/28 职场文书
领导干部个人整改措施落实情况汇报
2014/10/29 职场文书
受资助学生感谢信
2015/01/21 职场文书
Python实现文本文件拆分写入到多个文本文件的方法
2021/04/18 Python
Java中try catch处理异常示例
2021/12/06 Java/Android