pytorch实现focal loss的两种方式小结


Posted in Python onJanuary 02, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
'''
pytorch实现focal loss的两种方式(现在讨论的是基于分割任务)
在计算损失函数的过程中考虑到类别不平衡的问题,假设加上背景类别共有6个类别
'''
def compute_class_weights(histogram):
  classWeights = np.ones(6, dtype=np.float32)
  normHist = histogram / np.sum(histogram)
  for i in range(6):
    classWeights[i] = 1 / (np.log(1.10 + normHist[i]))
  return classWeights
def focal_loss_my(input,target):
  '''
  :param input: shape [batch_size,num_classes,H,W] 仅仅经过卷积操作后的输出,并没有经过任何激活函数的作用
  :param target: shape [batch_size,H,W]
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)
  '''
  根据当前给出的ground truth label计算出每个类别所占据的权重
  '''

  # weights=torch.from_numpy(classWeights).float().cuda()
  weights = torch.from_numpy(classWeights).float()
  focal_frequency = F.nll_loss(F.softmax(input, dim=1), target, reduction='none')
  '''
  上面一篇博文讲过
  F.nll_loss(torch.log(F.softmax(inputs, dim=1),target)的函数功能与F.cross_entropy相同
  可见F.nll_loss中实现了对于target的one-hot encoding编码功能,将其编码成与input shape相同的tensor
  然后与前面那一项(即F.nll_loss输入的第一项)进行 element-wise production
  相当于取出了 log(p_gt)即当前样本点被分类为正确类别的概率
  现在去掉取log的操作,相当于 focal_frequency shape [num_samples]
  即取出ground truth类别的概率数值,并取了负号
  '''

  focal_frequency += 1.0#shape [num_samples] 1-P(gt_classes)

  focal_frequency = torch.pow(focal_frequency, 2) # torch.Size([75])
  focal_frequency = focal_frequency.repeat(c, 1)
  '''
  进行repeat操作后,focal_frequency shape [num_classes,num_samples]
  '''
  focal_frequency = focal_frequency.transpose(1, 0)
  loss = F.nll_loss(focal_frequency * (torch.log(F.softmax(input, dim=1))), target, weight=None,
           reduction='elementwise_mean')
  return loss


def focal_loss_zhihu(input, target):
  '''
  :param input: 使用知乎上面大神给出的方案 https://zhuanlan.zhihu.com/p/28527749
  :param target:
  :return:
  '''
  n, c, h, w = input.size()

  target = target.long()
  inputs = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
  target = target.contiguous().view(-1)

  N = inputs.size(0)
  C = inputs.size(1)

  number_0 = torch.sum(target == 0).item()
  number_1 = torch.sum(target == 1).item()
  number_2 = torch.sum(target == 2).item()
  number_3 = torch.sum(target == 3).item()
  number_4 = torch.sum(target == 4).item()
  number_5 = torch.sum(target == 5).item()

  frequency = torch.tensor((number_0, number_1, number_2, number_3, number_4, number_5), dtype=torch.float32)
  frequency = frequency.numpy()
  classWeights = compute_class_weights(frequency)

  weights = torch.from_numpy(classWeights).float()
  weights=weights[target.view(-1)]#这行代码非常重要

  gamma = 2

  P = F.softmax(inputs, dim=1)#shape [num_samples,num_classes]

  class_mask = inputs.data.new(N, C).fill_(0)
  class_mask = Variable(class_mask)
  ids = target.view(-1, 1)
  class_mask.scatter_(1, ids.data, 1.)#shape [num_samples,num_classes] one-hot encoding

  probs = (P * class_mask).sum(1).view(-1, 1)#shape [num_samples,]
  log_p = probs.log()

  print('in calculating batch_loss',weights.shape,probs.shape,log_p.shape)

  # batch_loss = -weights * (torch.pow((1 - probs), gamma)) * log_p
  batch_loss = -(torch.pow((1 - probs), gamma)) * log_p

  print(batch_loss.shape)

  loss = batch_loss.mean()
  return loss

if __name__=='__main__':
  pred=torch.rand((2,6,5,5))
  y=torch.from_numpy(np.random.randint(0,6,(2,5,5)))
  loss1=focal_loss_my(pred,y)
  loss2=focal_loss_zhihu(pred,y)

  print('loss1',loss1)
  print('loss2', loss2)
'''
in calculating batch_loss torch.Size([50]) torch.Size([50, 1]) torch.Size([50, 1])
torch.Size([50, 1])
loss1 tensor(1.3166)
loss2 tensor(1.3166)
'''

以上这篇pytorch实现focal loss的两种方式小结就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python3百度指数抓取实例
Dec 12 Python
python实现微信接口(itchat)详细介绍
Oct 23 Python
python如何重载模块实例解析
Jan 25 Python
学生信息管理系统Python面向对象版
Jan 30 Python
python文件写入write()的操作
May 14 Python
学习python分支结构
May 17 Python
python爬虫之快速对js内容进行破解
Jul 09 Python
django创建最简单HTML页面跳转方法
Aug 16 Python
新手入门学习python Numpy基础操作
Mar 02 Python
jupyter使用自动补全和切换默认浏览器的方法
Nov 18 Python
Python 随机按键模拟2小时
Dec 30 Python
Python实现为PDF去除水印的示例代码
Apr 03 Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
Python加密模块的hashlib,hmac模块使用解析
Jan 02 #Python
在win64上使用bypy进行百度网盘文件上传功能
Jan 02 #Python
pytorch实现onehot编码转为普通label标签
Jan 02 #Python
You might like
PHP读取txt文件的内容并赋值给数组的代码
2011/11/03 PHP
php实现仿写CodeIgniter的购物车类
2015/07/29 PHP
php实现小程序支付完整版
2018/10/09 PHP
简短几句 通俗解释javascript的闭包
2011/01/17 Javascript
设为首页加入收藏兼容360/火狐/谷歌/IE等主流浏览器的代码
2013/03/26 Javascript
javascript:void(0)是什么意思示例介绍
2013/11/17 Javascript
jQuery遍历之next()、nextAll()方法使用实例
2014/11/08 Javascript
jquery+ajax实现直接提交表单实例分析
2016/06/17 Javascript
jQuery实现带遮罩层效果的blockUI弹出层示例【附demo源码下载】
2016/09/14 Javascript
jQuery实现的小图列表,大图展示效果幻灯片示例
2016/10/25 Javascript
JavaScript数据结构之数组的表示方法示例
2017/04/12 Javascript
addeventlistener监听scroll跟touch(实例讲解)
2017/08/04 Javascript
Vue的实例、生命周期与Vue脚手架(vue-cli)实例详解
2017/12/27 Javascript
node错误处理与日志记录的实现
2018/12/24 Javascript
weui中的picker使用js进行动态绑定数据问题
2019/11/06 Javascript
Python处理文本文件中控制字符的方法
2017/02/07 Python
python利用有道翻译实现"语言翻译器"的功能实例
2017/11/14 Python
python学习基础之循环import及import过程
2018/04/22 Python
Python cv2 图像自适应灰度直方图均衡化处理方法
2018/12/07 Python
对pytorch网络层结构的数组化详解
2018/12/08 Python
使用python opencv对目录下图片进行去重的方法
2019/01/12 Python
Python 静态方法和类方法实例分析
2019/11/21 Python
Python中如何将一个类方法变为多个方法
2019/12/30 Python
python使用Geany编辑器配置方法
2020/02/21 Python
Python类和实例的属性机制原理详解
2020/03/21 Python
Jupyter notebook 远程配置及SSL加密教程
2020/04/14 Python
python实现程序重启和系统重启方式
2020/04/16 Python
基于Python脚本实现邮件报警功能
2020/05/20 Python
英国Amara家居法国网站:家居装饰,现代装饰和豪华礼品
2016/12/15 全球购物
科茨沃尔德家居商店:Scotts of Stow
2018/06/29 全球购物
李维斯牛仔裤英国官方网站:Levi’s英国
2019/10/10 全球购物
中专生求职自荐信范文
2013/12/22 职场文书
初中升旗仪式演讲稿
2014/05/08 职场文书
公安机关查摆剖析材料
2014/10/10 职场文书
SpringBoot 集成Redis 过程
2021/06/02 Redis
Spark SQL 2.4.8 操作 Dataframe的两种方式
2021/10/16 SQL Server