pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现同时兼容老版和新版Socket协议的一个简单WebSocket服务器
Jun 04 Python
Windows下Anaconda的安装和简单使用方法
Jan 04 Python
python 爬虫 批量获取代理ip的实例代码
May 22 Python
Python实现图片拼接的代码
Jul 02 Python
Django分页查询并返回jsons数据(中文乱码解决方法)
Aug 02 Python
Python 单元测试(unittest)的使用小结
Nov 14 Python
pycharm远程开发项目的实现步骤
Jan 20 Python
Python+OpenCV图片局部区域像素值处理改进版详解
Jan 23 Python
Python中的 is 和 == 以及字符串驻留机制详解
Jun 28 Python
python 使用递归的方式实现语义图片分割功能
Jul 16 Python
最新Python idle下载、安装与使用教程图文详解
Nov 28 Python
python中类与对象之间的关系详解
Dec 16 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
一条久听不愿放下的DIY森海MX500,三言两语话神奇
2021/03/02 无线电
thinkphp3.2实现上传图片的控制器方法
2016/04/28 PHP
PHP实现的XXTEA加密解密算法示例
2018/08/28 PHP
你可能不再需要JQUERY
2021/03/09 Javascript
接收键盘指令的脚本
2006/06/26 Javascript
Javascript Math ceil()、floor()、round()三个函数的区别
2010/03/09 Javascript
jquery模拟alert的弹窗插件
2015/07/31 Javascript
JavaScript实现动态删除列表框值的方法
2015/08/12 Javascript
js变量提升深入理解
2016/09/16 Javascript
Bootstrap模态框禁用空白处点击关闭
2016/10/20 Javascript
Vue Ajax跨域请求实例详解
2017/06/20 Javascript
r.js来合并压缩css文件的示例
2018/04/26 Javascript
讲解vue-router之什么是嵌套路由
2018/05/28 Javascript
从零开始用electron手撸一个截屏工具的示例代码
2018/10/10 Javascript
中级前端工程师必须要掌握的27个JavaScript 技巧(干货总结)
2019/09/23 Javascript
IntelliJ IDEA编辑器配置vue高亮显示
2019/09/26 Javascript
[01:48]完美圣典齐天大圣至宝宣传片
2016/12/17 DOTA
使用django-suit为django 1.7 admin后台添加模板
2014/11/18 Python
再谈Python中的字符串与字符编码(推荐)
2016/12/14 Python
Python 使用with上下文实现计时功能
2018/03/09 Python
Python实现定时精度可调节的定时器
2018/04/15 Python
python语音识别实践之百度语音API
2018/08/30 Python
详解django中url路由配置及渲染方式
2019/02/25 Python
彻底理解Python中的yield关键字
2019/04/01 Python
python如何获取列表中每个元素的下标位置
2019/07/01 Python
pytorch的batch normalize使用详解
2020/01/15 Python
基于python实现监听Rabbitmq系统日志代码示例
2020/11/28 Python
iostream与iostream.h的区别
2015/01/16 面试题
对象的序列化(serialization)类是面向流的,应如何将对象写入到随机存取文件中
2015/06/22 面试题
人事专员岗位职责说明书
2014/07/30 职场文书
股东授权委托书
2014/10/15 职场文书
节约用电通知
2015/04/25 职场文书
公安纪律作风整顿心得体会
2016/01/23 职场文书
MySQL创建索引需要了解的
2021/04/08 MySQL
Ajax请求超时与网络异常处理图文详解
2021/05/23 Javascript
原型和原型链 prototype和proto的区别详情
2021/11/02 Javascript