pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python标准库之sqlite3使用实例
Nov 25 Python
使用C语言扩展Python程序的简单入门指引
Apr 14 Python
python修改字典内key对应值的方法
Jul 11 Python
Python cookbook(字符串与文本)在字符串的开头或结尾处进行文本匹配操作
Apr 20 Python
浅谈PySpark SQL 相关知识介绍
Jun 14 Python
python接口自动化如何封装获取常量的类
Dec 24 Python
使用python实现希尔、计数、基数基础排序的代码
Dec 25 Python
手动安装python3.6的操作过程详解
Jan 13 Python
Python 字典中的所有方法及用法
Jun 10 Python
python可以用哪些数据库
Jun 22 Python
python实现自动清理文件夹旧文件
May 10 Python
python实现自定义日志的具体方法
May 28 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
用PHP与XML联手进行网站编程代码实例
2008/07/10 PHP
PDO::errorInfo讲解
2019/01/28 PHP
laravel 执行迁移回滚示例
2019/10/23 PHP
显示、隐藏密码
2006/07/01 Javascript
JS中动态添加事件(绑定事件)的代码
2011/01/09 Javascript
jQuery 获取URL的GET参数值的小例子
2013/04/18 Javascript
js 自动播放的实例代码
2013/11/19 Javascript
ie9 提示'console' 未定义问题的解决方法
2014/03/20 Javascript
Node.js返回JSONP详解
2016/05/18 Javascript
vue-resource + json-server模拟数据的方法
2017/11/02 Javascript
vue.js 实现图片本地预览 裁剪 压缩 上传功能
2018/03/01 Javascript
微信小程序左滑删除功能开发案例详解
2018/11/12 Javascript
使用vue制作滑动标签
2019/09/21 Javascript
vue中动态select的使用方法示例
2019/10/28 Javascript
OpenLayers3实现鼠标移动显示坐标
2020/09/25 Javascript
antd日期选择器禁止选择当天之前的时间操作
2020/10/29 Javascript
[14:00]DOTA2国际邀请赛史上最长大战 赛后专访B神
2013/08/10 DOTA
Python中字符串格式化str.format的详细介绍
2017/02/17 Python
Python计算开方、立方、圆周率,精确到小数点后任意位的方法
2018/07/17 Python
python利用7z批量解压rar的实现
2019/08/07 Python
CSS3中颜色线性渐变实战
2015/07/18 HTML / CSS
浅析canvas元素的html尺寸和css尺寸对元素视觉的影响
2019/07/22 HTML / CSS
Lookfantastic瑞典:英国知名美妆购物网站
2018/04/06 全球购物
Mio Skincare英国官网:身体紧致及孕期身体护理
2018/08/19 全球购物
BLACKMORES澳洲官网:澳大利亚排名第一的保健品牌
2018/09/27 全球购物
什么是命名空间(NameSpace)
2015/11/24 面试题
swtich是否能作用在byte上,是否能作用在long上,是否能作用在String上?
2013/03/30 面试题
大学生党课思想汇报
2013/12/29 职场文书
劳动之星获奖感言
2014/02/01 职场文书
社区清明节活动总结
2014/07/04 职场文书
民政局副局长民主生活会个人整改措施
2014/10/04 职场文书
解决MySQL存储时间出现不一致的问题
2021/04/28 MySQL
PostgreSQL13基于流复制搭建后备服务器的方法
2022/01/18 PostgreSQL
InterProcessMutex实现zookeeper分布式锁原理
2022/03/21 Java/Android
阿里面试Nacos配置中心交互模型是push还是pull原理解析
2022/07/23 Java/Android
nginx配置指令之server_name的具体使用
2022/08/14 Servers