pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python错误:AttributeError: 'module' object has no attribute 'setdefaultencoding'问题的解决方法
Aug 22 Python
使用Python发送各种形式的邮件的方法汇总
Nov 09 Python
使用Python简单的实现树莓派的WEB控制
Feb 18 Python
python正则中最短匹配实现代码
Jan 16 Python
Anaconda下安装mysql-python的包实例
Jun 11 Python
一文了解Python并发编程的工程实现方法
May 31 Python
简单的Python调度器Schedule详解
Aug 30 Python
python中调试或排错的五种方法示例
Sep 12 Python
PIL包中Image模块的convert()函数的具体使用
Feb 26 Python
Python Unittest原理及基本使用方法
Nov 06 Python
python爬取豆瓣电影TOP250数据
May 23 Python
pytorch 如何把图像数据集进行划分成train,test和val
May 31 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
获得Google PR值的PHP代码
2007/01/28 PHP
php实现utf-8转unicode函数分享
2015/01/06 PHP
php类的扩展和继承用法实例
2015/06/20 PHP
Yii2 批量插入、更新数据实例
2017/03/15 PHP
PHP从数组中删除元素的四种方法实例
2017/05/12 PHP
ThinkPHP3.2.3框架实现执行原生SQL语句的方法示例
2019/04/03 PHP
jQuery 菜单随滚条改为以定位方式(固定要浏览器顶部)
2012/05/24 Javascript
jQuery Ajax提交表单查询获得数据实例代码
2012/09/19 Javascript
javascript 循环调用示例介绍
2013/11/20 Javascript
js仿支付宝多方框输入支付密码效果
2016/09/27 Javascript
使用jQuery ajaxupload插件实现无刷新上传文件
2017/04/23 jQuery
详解关于Vuex的action传入多个参数的问题
2019/02/22 Javascript
javascript面向对象三大特征之多态实例详解
2019/07/24 Javascript
Vue CLI项目 axios模块前后端交互的使用(类似ajax提交)
2019/09/01 Javascript
vue进入页面时不在顶部,检测滚动返回顶部按钮问题及解决方法
2019/10/30 Javascript
[05:00]第二届DOTA2亚洲邀请赛主赛事第三天比赛集锦.mp4
2017/04/04 DOTA
利用Python实现Windows定时关机功能
2017/03/21 Python
Python获取SQLite查询结果表列名的方法
2017/06/21 Python
ubuntu环境下python虚拟环境的安装过程
2018/01/07 Python
Python实现的爬取网易动态评论操作示例
2018/06/06 Python
深入浅析Python中list的复制及深拷贝与浅拷贝
2018/09/03 Python
python实现简单的文字识别
2018/11/27 Python
Python企业编码生成系统之系统主要函数设计详解
2019/07/26 Python
python写程序统计词频的方法
2019/07/29 Python
html5 Canvas实现图片旋转的示例
2018/01/15 HTML / CSS
校园广播稿500字
2014/02/04 职场文书
畜牧兽医本科生的自我评价
2014/03/03 职场文书
《会走路的树》教后反思
2014/04/19 职场文书
读书活动总结范文
2014/04/26 职场文书
大专生求职信
2014/06/29 职场文书
新颖的化妆品活动方案
2014/08/21 职场文书
个人四风对照检查材料
2014/09/26 职场文书
党员个人查摆剖析材料
2014/10/16 职场文书
优秀员工自荐书
2015/03/06 职场文书
2015年客房服务员工作总结
2015/05/15 职场文书
个人售房合同协议书
2016/03/21 职场文书