pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python代码来解图片迷宫的方法整理
Apr 02 Python
python结合shell查询google关键词排名的实现代码
Feb 27 Python
Python列表切片用法示例
Apr 19 Python
浅谈Python2.6和Python3.0中八进制数字表示的区别
Apr 28 Python
详解如何使用Python编写vim插件
Nov 28 Python
Python基于OpenCV库Adaboost实现人脸识别功能详解
Aug 25 Python
Python实现将字符串的首字母变为大写,其余都变为小写的方法
Jun 11 Python
Python 网络编程之UDP发送接收数据功能示例【基于socket套接字】
Oct 11 Python
浅析Python面向对象编程
Jul 10 Python
pyqt5 textEdit、lineEdit操作的示例代码
Aug 12 Python
Python将list元素转存为CSV文件的实现
Nov 16 Python
在pycharm中无法import所安装的库解决方案
May 31 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
在PHP中养成7个面向对象的好习惯
2010/07/17 PHP
DEDE采集大师官方留后门的删除办法
2011/01/08 PHP
PHP文件管理之实现网盘及压缩包的功能操作
2017/09/20 PHP
PHP信号处理机制的操作代码讲解
2019/04/19 PHP
使用Entrust扩展包在laravel 中实现RBAC的功能
2020/03/16 PHP
javascript日期转换 时间戳转日期格式
2011/11/05 Javascript
javascript实现数独解法
2015/03/14 Javascript
jquery显示loading图片直到网页加载完成的方法
2015/06/25 Javascript
JavaScript实现可拖拽的拖动层Div实例
2015/08/05 Javascript
javascript中eval和with用法实例总结
2015/11/30 Javascript
javascript计时器编写过程与实现方法
2016/02/29 Javascript
JavaScript中windows.open()、windows.close()方法详解
2016/07/28 Javascript
Angularjs 实现分页功能及示例代码
2016/09/14 Javascript
JS控件bootstrap datepicker使用方法详解
2017/03/25 Javascript
了解VUE的render函数的使用
2017/06/08 Javascript
基于require.js的使用(实例讲解)
2017/09/07 Javascript
jQuery实现简单的下拉菜单导航功能示例
2017/12/07 jQuery
elementUI中Table表格问题的解决方法
2018/12/04 Javascript
js实现登录时记住密码的方法分析
2020/04/05 Javascript
详解JavaScript之ES5的继承
2020/07/08 Javascript
js实现盒子移动动画效果
2020/08/09 Javascript
python网络编程学习笔记(四):域名系统
2014/06/09 Python
关于你不想知道的所有Python3 unicode特性
2014/11/28 Python
Python中列表和元组的使用方法和区别详解
2020/12/30 Python
python中字符串类型json操作的注意事项
2017/05/02 Python
python Pexpect 实现输密码 scp 拷贝的方法
2019/01/03 Python
python将四元数变换为旋转矩阵的实例
2019/12/04 Python
tensorflow查看ckpt各节点名称实例
2020/01/21 Python
python 读取二进制 显示图片案例
2020/04/24 Python
Sentry错误日志监控使用方法解析
2020/11/12 Python
泰国演唱会订票网站:StubHub泰国
2018/02/26 全球购物
我未来的职业规划范文
2014/01/11 职场文书
学生周末长期请假条
2014/02/15 职场文书
受伤赔偿协议书
2014/09/24 职场文书
2014年科技工作总结
2014/11/26 职场文书
2016年寒假学习心得体会
2015/10/09 职场文书