在Pytorch中使用样本权重(sample_weight)的正确方法


Posted in Python onAugust 17, 2019

step:

1.将标签转换为one-hot形式。

2.将每一个one-hot标签中的1改为预设样本权重的值

即可在Pytorch中使用样本权重。

eg:

对于单个样本:loss = - Q * log(P),如下:

P = [0.1,0.2,0.4,0.3]
Q = [0,0,1,0]
loss = -Q * np.log(P)

增加样本权重则为loss = - Q * log(P) *sample_weight

P = [0.1,0.2,0.4,0.3]
Q = [0,0,sample_weight,0]
loss_samle_weight = -Q * np.log(P)

在pytorch中示例程序

train_data = np.load(open('train_data.npy','rb'))
train_labels = []
for i in range(8):
  train_labels += [i] *100
train_labels = np.array(train_labels)
train_labels = to_categorical(train_labels).astype("float32")
sample_1 = [random.random() for i in range(len(train_data))]
for i in range(len(train_data)):
  floor = i / 100
  train_labels[i][floor] = sample_1[i]
train_data = torch.from_numpy(train_data) 
train_labels = torch.from_numpy(train_labels) 
dataset = dataf.TensorDataset(train_data,train_labels) 
trainloader = dataf.DataLoader(dataset, batch_size=batch_size, shuffle=True)

对应one-target的多分类交叉熵损失函数如下:

def my_loss(outputs, targets):
  
  output2 = outputs - torch.max(outputs, 1, True)[0]
 
 
  P = torch.exp(output2) / torch.sum(torch.exp(output2), 1,True) + 1e-10
 
 
  loss = -torch.mean(targets * torch.log(P))
 
 
  return loss

以上这篇在Pytorch中使用样本权重(sample_weight)的正确方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用beautifulsoup从爱奇艺网抓取视频播放
Jan 23 Python
python实现查询IP地址所在地
Mar 29 Python
在Python的Django框架中加载模版的方法
Jul 16 Python
Python正则表达式教程之一:基础篇
Mar 02 Python
利用python模拟sql语句对员工表格进行增删改查
Jul 05 Python
Python分治法定义与应用实例详解
Jul 28 Python
Python画柱状统计图操作示例【基于matplotlib库】
Jul 04 Python
Python字符串逆序输出的实例讲解
Feb 16 Python
Python3中列表list合并的四种方法
Apr 19 Python
Python二进制文件读取并转换为浮点数详解
Jun 25 Python
简单了解python协程的相关知识
Aug 31 Python
python随机数分布random均匀分布实例
Nov 27 Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
You might like
PHP与MySQL开发中页面出现乱码的一种解决方法
2007/07/29 PHP
基于php socket(fsockopen)的应用实例分析
2013/06/02 PHP
php算法实例分享
2015/07/14 PHP
php源码的使用方法讲解
2019/09/26 PHP
php高性能日志系统 seaslog 的安装与使用方法分析
2020/02/29 PHP
javascript编程起步(第七课)
2007/01/10 Javascript
jquery tools系列 expose 学习
2009/09/06 Javascript
JavaScript中的toLocaleDateString()方法使用简介
2015/06/12 Javascript
Node.js实现JS文件合并小工具
2016/02/02 Javascript
javascript 继承学习心得总结
2016/03/17 Javascript
javascript创建对象的几种模式介绍
2016/05/06 Javascript
jQuery文字轮播特效
2017/02/12 Javascript
Vue2.x中的父子组件相互通信的实现方法
2017/05/02 Javascript
SpringMVC简单整合Angular2的示例
2017/07/31 Javascript
vue+socket.io+express+mongodb 实现简易多房间在线群聊示例
2017/10/21 Javascript
js实现简单进度条效果
2020/03/25 Javascript
vue接通后端api以及部署到服务器操作
2020/08/13 Javascript
jQuery+ajax实现用户登录验证
2020/09/13 jQuery
Python中的异常处理相关语句基础学习笔记
2016/07/11 Python
Python常见异常分类与处理方法
2017/06/04 Python
Python 查看文件的读写权限方法
2018/01/23 Python
解决python报错MemoryError的问题
2018/06/26 Python
python制作填词游戏步骤详解
2019/05/05 Python
解决pycharm安装第三方库失败的问题
2020/05/09 Python
django跳转页面传参的实现
2020/09/17 Python
C++如何引用一个已经定义过的全局变量
2014/08/25 面试题
厨房工作人员岗位职责
2013/11/15 职场文书
寒假实习自荐信
2014/01/26 职场文书
教师党性分析材料
2014/02/04 职场文书
会计专业求职信
2014/08/10 职场文书
2014年中职班主任工作总结
2014/12/16 职场文书
怎样写辞职信
2015/02/27 职场文书
MySQL如何解决幻读问题
2021/08/07 MySQL
使用pipenv管理python虚拟环境的全过程
2021/09/25 Python
面试分析分布式架构Redis热点key大Value解决方案
2022/03/13 Redis
Python万能模板案例之matplotlib绘制甘特图
2022/04/13 Python