pytorch交叉熵损失函数的weight参数的使用


Posted in Python onMay 24, 2021

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

pytorch交叉熵损失函数的weight参数的使用

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则匹配抓取豆瓣电影链接和评论代码分享
Dec 27 Python
Python中对列表排序实例
Jan 04 Python
python数据清洗系列之字符串处理详解
Feb 12 Python
python读取二进制mnist实例详解
May 31 Python
CentOS7下python3.7.0安装教程
Jul 30 Python
使用PyQt4 设置TextEdit背景的方法
Jun 14 Python
ipad上运行python的方法步骤
Oct 12 Python
Python写出新冠状病毒确诊人数地图的方法
Feb 12 Python
python GUI库图形界面开发之PyQt5表格控件QTableView详细使用方法与实例
Mar 01 Python
python保留格式汇总各部门excel内容的实现思路
Jun 01 Python
python打开音乐文件的实例方法
Jul 21 Python
基于Python正确读取资源文件
Sep 14 Python
pytorch 实现变分自动编码器的操作
May 24 #Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 #Python
Python基础之函数嵌套知识总结
May 23 #Python
利用python Pandas实现批量拆分Excel与合并Excel
May 23 #Python
Python基础之元编程知识总结
May 23 #Python
Python利用folium实现地图可视化
python爬虫之selenium库的安装及使用教程
You might like
js parentElement和offsetParent之间的区别
2010/03/23 Javascript
jQuery遍历Table应用示例
2014/04/09 Javascript
Lab.js初次使用笔记
2015/02/28 Javascript
js监听键盘事件的方法_原生和jquery的区别详解
2016/10/10 Javascript
Vue数组更新及过滤排序功能
2017/08/10 Javascript
vue.js vue-router如何实现无效路由(404)的友好提示
2017/12/20 Javascript
详解Vue单元测试case写法
2018/05/24 Javascript
Angular搜索场景中使用rxjs的操作符处理思路
2018/05/30 Javascript
微信小程序scroll-view横向滑动嵌套for循环的示例代码
2018/09/20 Javascript
JavaScript数组、json对象、eval()函数用法实例分析
2019/02/21 Javascript
mongodb初始化并使用node.js实现mongodb操作封装方法
2019/04/02 Javascript
JQuery+Bootstrap 自定义全屏Loading插件的示例demo
2019/07/03 jQuery
Angular如何由模板生成DOM树的方法
2019/12/23 Javascript
JavaScript实现单点登录的示例
2020/09/23 Javascript
一些Python中的二维数组的操作方法
2015/05/02 Python
Python 中的 else详解
2016/04/23 Python
python计算两个地址之间的距离方法
2018/06/09 Python
python计算两个矩形框重合百分比的实例
2018/11/07 Python
详解python多线程之间的同步(一)
2019/04/03 Python
Python增强赋值和共享引用注意事项小结
2019/05/28 Python
基于python操作ES实例详解
2019/11/16 Python
Pytorch中膨胀卷积的用法详解
2020/01/07 Python
使用Python来做一个屏幕录制工具的操作代码
2020/01/18 Python
解决margin 外边距合并问题
2019/07/03 HTML / CSS
SQL Server的固定数据库角色都有哪些?对应的服务器权限有哪些?
2013/05/18 面试题
一份软件工程师的面试试题
2016/02/01 面试题
药剂专业学生求职信范文
2013/12/28 职场文书
化妆品店促销方案
2014/02/24 职场文书
物理学专业求职信
2014/07/04 职场文书
党性锻炼的心得体会
2014/09/03 职场文书
银行贷款收入证明
2014/10/17 职场文书
大学生考试作弊被抓检讨书
2014/12/27 职场文书
2015年卫生监督工作总结
2015/05/21 职场文书
Python源码解析之List
2021/05/21 Python
7个关于Python的经典基础案例
2021/11/07 Python
nginx中封禁ip和允许内网ip访问的实现示例
2022/03/17 Servers