pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python脚本实现网卡流量监控
Feb 14 Python
Python中字典的基本知识初步介绍
May 21 Python
Python的Asyncore异步Socket模块及实现端口转发的例子
Jun 14 Python
Python基于回溯法子集树模板解决取物搭配问题实例
Sep 02 Python
Python中矩阵创建和矩阵运算方法
Aug 04 Python
Python SQL查询并生成json文件操作示例
Aug 17 Python
对python xlrd读取datetime类型数据的方法详解
Dec 26 Python
Python之指数与E记法的区别详解
Nov 21 Python
Python基于Socket实现简单聊天室
Feb 17 Python
python实现超级玛丽游戏
Mar 18 Python
简单了解Java Netty Reactor三种线程模型
Apr 26 Python
Python接口测试文件上传实例解析
May 22 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
PHP合并两个数组的两种方式的异同
2012/09/14 PHP
PHP中预定义的6种接口介绍
2015/05/12 PHP
CI分页类首页、尾页不显示的解决方法
2016/03/28 PHP
thinkPHP5实现的查询数据库并返回json数据实例
2017/10/23 PHP
激活 ActiveX 控件
2006/10/09 Javascript
详解new function(){}和function(){}() 区别分析
2008/03/22 Javascript
JQuery入门——移除绑定事件unbind方法概述及应用
2013/02/05 Javascript
jQuery中dequeue()方法用法实例
2014/12/29 Javascript
jQuery中Ajax的load方法详解
2015/01/14 Javascript
JS实现文字放大效果的方法
2015/03/03 Javascript
老生常谈 js中this的指向
2016/06/30 Javascript
ES6入门教程之Class和Module详解
2017/05/17 Javascript
Bootstrap Table使用整理(二)
2017/06/09 Javascript
详解JavaScript中的强制类型转换
2019/04/15 Javascript
微信小程序页面上下滚动效果
2020/11/18 Javascript
基于ssm框架实现layui分页效果
2019/07/27 Javascript
[54:51]Ti4 冒泡赛第二轮LGD vs C9 3
2014/07/14 DOTA
Python使用matplotlib填充图形指定区域代码示例
2018/01/16 Python
django文档学习之applications使用详解
2018/01/29 Python
Python 12306抢火车票脚本 Python京东抢手机脚本
2018/02/06 Python
python版本的仿windows计划任务工具
2018/04/30 Python
详解将Pandas中的DataFrame类型转换成Numpy中array类型的三种方法
2019/07/06 Python
scrapy实践之翻页爬取的实现
2021/01/05 Python
canvas学习总结三之绘制路径-线段
2019/01/31 HTML / CSS
兰蔻加拿大官方网站:Lancome加拿大
2016/08/05 全球购物
为世界各地的女性设计和生产时尚服装:ROMWE
2016/09/17 全球购物
size?荷兰官方网站:英国高级运动鞋精品店
2020/07/24 全球购物
js实现弹框效果
2021/03/24 Javascript
咖啡蛋糕店创业计划书
2014/01/28 职场文书
中国梦主题教育活动总结
2014/05/05 职场文书
村庄环境整治方案
2014/05/15 职场文书
党员干部批评与自我批评反四风思想汇报
2014/09/21 职场文书
2014年村委会工作总结
2014/11/24 职场文书
多属性、多分类MySQL模式设计
2021/04/05 MySQL
Python入门之基础语法详解
2021/05/11 Python
oracle覆盖导入dmp文件的2种方法
2021/05/21 Oracle