pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用poplib模块和smtplib模块收发电子邮件的教程
Jul 02 Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 Python
Python实现按当前日期(年、月、日)创建多级目录的方法
Apr 26 Python
使用Python实现在Windows下安装Django
Oct 17 Python
在Pycharm terminal中字体大小设置的方法
Jan 16 Python
Python3 io文本及原始流I/O工具用法详解
Mar 23 Python
Python将二维列表list的数据输出(TXT,Excel)
Apr 23 Python
python语言中有算法吗
Jun 16 Python
谈谈python垃圾回收机制
Sep 27 Python
python跨文件使用全局变量的实现
Nov 17 Python
解决pytorch 数据类型报错的问题
Mar 03 Python
python自动化之如何利用allure生成测试报告
May 02 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
用PHP实现Ftp用户的在线管理
2012/02/16 PHP
php 注释规范
2012/03/29 PHP
PHP中error_log()函数的使用方法
2015/01/20 PHP
PHP SPL标准库之接口(Interface)详解
2015/05/11 PHP
PHP+JS实现的实时搜索提示功能
2018/03/13 PHP
js弹出框轻量级插件jquery.boxy使用介绍
2013/01/15 Javascript
jquery如何通过name名称获取当前name的value值
2013/12/20 Javascript
多选列表框动态添加,移动,删除,全选等操作的简单实例
2014/01/13 Javascript
JS实现的页面自定义滚动条效果
2015/10/26 Javascript
浅谈angular懒加载的一些坑
2016/08/20 Javascript
JS冒泡事件与事件捕获实例详解
2016/11/25 Javascript
JavaScript使用delete删除数组元素用法示例【数组长度不变】
2017/01/17 Javascript
面试常见的js算法题
2017/03/23 Javascript
Angularjs中使用轮播图指令swiper
2017/05/30 Javascript
JavaScript正则表达式校验与递归函数实际应用实例解析
2017/08/04 Javascript
原生JS实现网页手机音乐播放器 歌词同步播放的示例
2018/02/02 Javascript
angular2模块和共享模块详解
2018/04/08 Javascript
vuex 动态注册方法 registerModule的实现
2019/07/03 Javascript
vue实现下拉菜单树
2020/10/22 Javascript
Javascript表单序列化原理及实现代码详解
2020/10/30 Javascript
pycharm 使用心得(二)设置字体大小
2014/06/05 Python
Windows下为Python安装Matplotlib模块
2015/11/06 Python
Windows上使用virtualenv搭建Python+Flask开发环境
2016/06/07 Python
python机器学习实战之K均值聚类
2017/12/20 Python
python XlsxWriter模块创建aexcel表格的实例讲解
2018/05/03 Python
Python根据指定日期计算后n天,前n天是哪一天的方法
2018/05/29 Python
Python调用服务接口的实例
2019/01/03 Python
python numpy矩阵信息说明,shape,size,dtype
2020/05/22 Python
菲律宾最大的网上花店和礼品店:PhilFlower.com
2018/02/09 全球购物
remote接口和home接口主要作用
2013/05/15 面试题
英语简历自我评价
2014/01/26 职场文书
写得不错的求职信范文
2014/07/11 职场文书
常住证明范本
2015/06/23 职场文书
优秀党员主要事迹材料
2015/11/04 职场文书
建房合同协议书
2016/03/21 职场文书
Nginx反向代理、重定向
2022/04/13 Servers