pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
进一步探究Python的装饰器的运用
May 05 Python
浅谈python 读excel数值为浮点型的问题
Dec 25 Python
pandas修改DataFrame列名的实现方法
Feb 22 Python
使用PyQtGraph绘制精美的股票行情K线图的示例代码
Mar 14 Python
关于Python作用域自学总结
Jun 10 Python
python网络编程 使用UDP、TCP协议收发信息详解
Aug 29 Python
PyTorch的SoftMax交叉熵损失和梯度用法
Jan 15 Python
Python requests模块基础使用方法实例及高级应用(自动登陆,抓取网页源码)实例详解
Feb 14 Python
Pycharm 安装 idea VIM插件的图文教程详解
Feb 21 Python
python GUI库图形界面开发之PyQt5中QMainWindow, QWidget以及QDialog的区别和选择
Feb 26 Python
pycharm下pyqt4安装及环境配置的教程
Apr 24 Python
matplotlib更改窗口图标的方法示例
Feb 03 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
用缓存实现静态页面的测试
2006/12/06 PHP
PHP array_shift()用法实例分析
2019/01/07 PHP
详解使用php-cs-fixer格式化代码
2020/09/16 PHP
用js实现的页面关键字密度查询代码
2007/12/27 Javascript
JavaScript高级程序设计 阅读笔记(十二) js内置对象Math
2012/08/14 Javascript
让低版本浏览器支持input的placeholder属性(js方法)
2013/04/03 Javascript
JS简单实现元素复制示例附图
2013/11/19 Javascript
jquery map方法使用示例
2014/04/23 Javascript
JS实现根据当前文字选择返回被选中的文字
2014/05/21 Javascript
原生的html元素选择器类似jquery选择器
2014/10/15 Javascript
js仿土豆网带缩略图的焦点图片切换效果实现方法
2015/02/23 Javascript
jQuery插件boxScroll实现图片轮播特效
2015/07/14 Javascript
bootstrap表格分页实例讲解
2016/12/30 Javascript
详解如何在NodeJS项目中优雅的使用ES6
2017/04/22 NodeJs
vue 设置proxyTable参数进行代理跨域
2018/04/09 Javascript
基于Koa2写个脚手架模拟接口服务的方法
2018/11/27 Javascript
小程序两种滚动公告栏的实现方法
2019/09/17 Javascript
vue+axios全局添加请求头和参数操作
2020/07/24 Javascript
[03:42]2014DOTA2西雅图国际邀请赛7月9日TOPPLAY
2014/07/09 DOTA
Python实现的快速排序算法详解
2017/08/01 Python
使用tensorflow实现线性svm
2018/09/07 Python
python 使用pandas计算累积求和的方法
2019/02/08 Python
详解Python基础random模块随机数的生成
2019/03/23 Python
对pyqt5多线程正确的开启姿势详解
2019/06/14 Python
python代码打印100-999之间的回文数示例
2019/11/24 Python
解决pycharm编辑区显示yaml文件层级结构遇中文乱码问题
2020/04/27 Python
Python分类测试代码实例汇总
2020/07/23 Python
美国女性服饰销售网站:Nasty Gal(坏女孩)
2016/07/26 全球购物
雅诗兰黛旗下专业男士保养领导品牌:Lab Series
2017/05/15 全球购物
Osklen官方在线商店:巴西服装品牌
2019/04/25 全球购物
口腔工艺技术专业毕业生自荐信
2013/09/27 职场文书
理工大学毕业生自荐信范文
2014/02/22 职场文书
幼儿园清明节活动总结
2014/07/04 职场文书
校园会短篇的广播稿
2014/10/21 职场文书
导游词之新疆-喀纳斯
2019/10/10 职场文书
vue+echarts实现多条折线图
2022/03/21 Vue.js