pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python下如何让web元素的生成更简单的分析
Jul 17 Python
python实现合并两个数组的方法
May 16 Python
python实现在windows服务中新建进程的方法
Jun 30 Python
python实现获取Ip归属地等信息
Aug 27 Python
PyCharm安装第三方库如Requests的图文教程
May 18 Python
Python中一些不为人知的基础技巧总结
May 19 Python
Python 通过requests实现腾讯新闻抓取爬虫的方法
Feb 22 Python
Python+OpenCv制作证件图片生成器的操作方法
Aug 21 Python
Python虚拟环境的创建和使用详解
Sep 07 Python
详解如何使用Pytest进行自动化测试
Jan 14 Python
python实现图片转字符画的完整代码
Feb 21 Python
asyncio异步编程之Task对象详解
Mar 13 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
PHP Zip解压 文件在线解压缩的函数代码
2010/05/26 PHP
学习php设计模式 php实现策略模式(strategy)
2015/12/07 PHP
js的with语句使用方法
2007/09/21 Javascript
js电信网通双线自动选择技巧
2008/11/18 Javascript
jQuery中事件对象e的事件冒泡用法示例介绍
2014/04/25 Javascript
js中利用tagname和id获取元素的方法
2016/01/03 Javascript
JS快速实现移动端拼图游戏
2016/09/05 Javascript
JS简单随机数生成方法
2016/09/05 Javascript
实例解析jQuery中如何取消后续执行内容
2016/12/01 Javascript
Vue 表单控件绑定的实现示例
2017/08/11 Javascript
Angular5中调用第三方库及jQuery的添加的方法
2018/06/07 jQuery
JavaScript设计模式之工厂模式简单实例教程
2018/07/03 Javascript
详解webpack-dev-middleware 源码解读
2020/03/23 Javascript
《javascript设计模式》学习笔记一:Javascript面向对象程序设计对象成员的定义分析
2020/04/07 Javascript
关于AngularJS中几种Providers的区别总结
2020/05/17 Javascript
Ant Design Vue table中列超长显示...并加提示语的实例
2020/10/31 Javascript
Python列表list内建函数用法实例分析【insert、remove、index、pop等】
2017/07/24 Python
python进行两个表格对比的方法
2018/06/27 Python
Python 中的range(),以及列表切片方法
2018/07/02 Python
详解python做UI界面的方法
2019/02/27 Python
快速排序的四种python实现(推荐)
2019/04/03 Python
Python将文字转成语音并读出来的实例详解
2019/07/15 Python
python爬虫中多线程的使用详解
2019/09/23 Python
python多线程案例之多任务copy文件完整实例
2019/10/29 Python
Tensorflow安装问题: Could not find a version that satisfies the requirement tensorflow
2020/04/20 Python
Python使用xpath实现图片爬取
2020/09/16 Python
pytorch 移动端部署之helloworld的使用
2020/10/30 Python
使用gunicorn部署django项目的问题
2020/12/30 Python
德国最大的服装、鞋子和配件在线商店之一:Outfits24
2019/07/23 全球购物
上级检查欢迎词
2014/01/18 职场文书
运动会开幕式邀请函
2014/02/03 职场文书
2015年销售部工作总结范文
2015/04/27 职场文书
浅谈怎么给Python添加类型标注
2021/06/08 Python
nginx实现动静分离的方法示例
2021/11/07 Servers
MySQL创建管理RANGE分区
2022/04/13 MySQL
Mac电脑OS系统下安装Nginx的详细教程
2022/04/14 Servers