pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python通过websocket与js客户端通信示例分析
Jun 25 Python
python简单文本处理的方法
Jul 10 Python
Python在Windows和在Linux下调用动态链接库的教程
Aug 18 Python
python difflib模块示例讲解
Sep 13 Python
Python实现上下班抢个顺风单脚本
Feb 07 Python
python3.4.3下逐行读入txt文本并去重的方法
Apr 29 Python
用Python中的turtle模块画图两只小羊方法
Apr 09 Python
python aiohttp的使用详解
Jun 20 Python
Pytorch之保存读取模型实例
Dec 30 Python
pytorch实现focal loss的两种方式小结
Jan 02 Python
Python下使用Trackbar实现绘图板
Oct 27 Python
Python快速实现一键抠图功能的全过程
Jun 29 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
php 生成WML页面方法详解
2009/08/09 PHP
解决php使用异步调用获取数据时出现(错误c00ce56e导致此项操作无法完成)
2013/07/03 PHP
PHP Opcache安装和配置方法介绍
2015/05/28 PHP
Linux系统下PHP-FPM的安装和配置教程
2015/08/17 PHP
PHP检测用户是否关闭浏览器的方法
2016/02/14 PHP
浅谈PHP eval()函数定义和用法
2016/06/21 PHP
PHP的mysqli_sqlstate()函数讲解
2019/01/23 PHP
javascript跟随滚动效果插件代码(javascript Follow Plugin)
2013/08/03 Javascript
js字符串日期yyyy-MM-dd转化为date示例代码
2014/03/06 Javascript
复制网页内容,粘贴之后自动加上网址的实现方法(脚本之家特别整理)
2014/10/16 Javascript
node.js中的path.normalize方法使用说明
2014/12/08 Javascript
JavaScript代码实现左右上下自动晃动自动移动
2016/04/08 Javascript
JavaScript操作选择对象的简单实例
2016/05/16 Javascript
jQuery.cookie.js实现记录最近浏览过的商品功能示例
2017/01/23 Javascript
vue 使用自定义指令实现表单校验的方法
2018/08/28 Javascript
关于vue v-for 循环问题(一行显示四个,每一行的最右边那个计算属性)
2018/09/04 Javascript
vue中使用protobuf的过程记录
2018/10/26 Javascript
vue项目设置scrollTop不起作用(总结)
2018/12/21 Javascript
vue动画效果实现方法示例
2019/03/18 Javascript
Python入门篇之字典
2014/10/17 Python
Python中使用Boolean操作符做真值测试实例
2015/01/30 Python
简单谈谈Python中的json与pickle
2017/07/19 Python
关于Python中浮点数精度处理的技巧总结
2017/08/10 Python
Python实现决策树C4.5算法的示例
2018/05/30 Python
python中datetime模块中strftime/strptime函数的使用
2018/07/03 Python
对python指数、幂数拟合curve_fit详解
2018/12/29 Python
Python图像处理库PIL的ImageDraw模块介绍详解
2020/02/26 Python
pandas抽取行列数据的几种方法
2020/12/13 Python
HTML5 body设置全屏背景图片的示例代码
2020/12/08 HTML / CSS
Muziker英国:中欧最大的音乐家商店
2020/02/05 全球购物
什么是java序列化,如何实现java序列化
2012/11/14 面试题
工商治理实习生的自我评价
2014/01/15 职场文书
学习党章的体会
2014/11/07 职场文书
MySQL时间设置注意事项的深入总结
2021/05/06 MySQL
MySQL中几种插入和批量语句实例详解
2021/09/14 MySQL
Windows Server 版本 20H2 于 8 月 9 日停止支持,Win10 版本 21H1 将于 12 月结束支
2022/07/23 数码科技