pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(七):HTML和XHTML解析(HTMLParser、BeautifulSoup)
Jun 09 Python
Python struct模块解析
Jun 12 Python
Python实现简单的代理服务器
Jul 25 Python
python爬虫入门教程--HTML文本的解析库BeautifulSoup(四)
May 25 Python
python web.py开发httpserver解决跨域问题实例解析
Feb 12 Python
在cmd命令行里进入和退出Python程序的方法
May 12 Python
Python字符串匹配之6种方法的使用详解
Apr 08 Python
pyqt5使用按钮进行界面的跳转方法
Jun 19 Python
Python 解决OPEN读文件报错 ,路径以及r的问题
Dec 19 Python
Python3 使用selenium插件爬取苏宁商家联系电话
Dec 23 Python
通过代码实例解析Pytest运行流程
Aug 20 Python
python区块链实现简版工作量证明
May 25 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
php array_merge下进行数组合并的代码
2008/07/22 PHP
php在数据库抽象层简单使用PDO的方法
2015/11/03 PHP
PHP实现的一致性哈希算法完整实例
2015/11/14 PHP
php自定义截取中文字符串-utf8版
2017/02/27 PHP
JavaScript 继承的实现
2009/07/09 Javascript
javascript 45种缓动效果 非常酷
2011/06/28 Javascript
javascript的数据类型、字面量、变量介绍
2012/05/23 Javascript
showModelDialog弹出文件下载窗口的使用示例
2013/11/19 Javascript
关闭浏览器输入框自动补齐 兼容IE,FF,Chrome等主流浏览器
2014/02/11 Javascript
JS实现简单路由器功能的方法
2015/05/27 Javascript
javascript中caller和callee详解
2015/08/10 Javascript
深入浅析JS的数组遍历方法(推荐)
2016/06/15 Javascript
使用React实现轮播效果组件示例代码
2016/09/05 Javascript
Node.js连接mongodb实例代码
2017/06/06 Javascript
在Debian(Raspberry Pi)树莓派上安装NodeJS的教程详解
2017/09/19 NodeJs
浅谈Vue SPA 首屏加载优化实践
2017/12/15 Javascript
JavaScript设计模式之命令模式实例分析
2019/01/16 Javascript
nuxt中使用路由守卫的方法步骤
2019/01/27 Javascript
详解vue中$nextTick和$forceUpdate的用法
2019/12/11 Javascript
[00:15]TI9观赛名额抽取
2019/07/10 DOTA
python爬虫入门教程--优雅的HTTP库requests(二)
2017/05/25 Python
python利用正则表达式搜索单词示例代码
2017/09/24 Python
pandas 实现将重复表格去重,并重新转换为表格的方法
2018/04/18 Python
快速查找Python安装路径方法
2020/02/06 Python
解决Pycharm 导入其他文件夹源码的2种方法
2020/02/12 Python
一文了解python 3 字符串格式化 F-string 用法
2020/03/04 Python
Python接口测试数据库封装实现原理
2020/05/09 Python
俄罗斯和世界各地的酒店预订:Hotels.com俄罗斯
2016/08/19 全球购物
描述RIP和OSPF区别以及特点
2015/01/17 面试题
教育课题研究自我鉴定范文
2013/12/28 职场文书
元旦晚会邀请函
2014/02/01 职场文书
2014年元旦联欢会活动策划方案
2014/02/16 职场文书
新闻专业毕业生英文求职信
2014/03/19 职场文书
银行党的群众路线教育实践活动对照检查材料
2014/09/25 职场文书
欢迎家长标语
2014/10/08 职场文书
2015年教学工作总结
2015/04/02 职场文书