在Pytorch中使用样本权重(sample_weight)的正确方法


Posted in Python onAugust 17, 2019

step:

1.将标签转换为one-hot形式。

2.将每一个one-hot标签中的1改为预设样本权重的值

即可在Pytorch中使用样本权重。

eg:

对于单个样本:loss = - Q * log(P),如下:

P = [0.1,0.2,0.4,0.3]
Q = [0,0,1,0]
loss = -Q * np.log(P)

增加样本权重则为loss = - Q * log(P) *sample_weight

P = [0.1,0.2,0.4,0.3]
Q = [0,0,sample_weight,0]
loss_samle_weight = -Q * np.log(P)

在pytorch中示例程序

train_data = np.load(open('train_data.npy','rb'))
train_labels = []
for i in range(8):
  train_labels += [i] *100
train_labels = np.array(train_labels)
train_labels = to_categorical(train_labels).astype("float32")
sample_1 = [random.random() for i in range(len(train_data))]
for i in range(len(train_data)):
  floor = i / 100
  train_labels[i][floor] = sample_1[i]
train_data = torch.from_numpy(train_data) 
train_labels = torch.from_numpy(train_labels) 
dataset = dataf.TensorDataset(train_data,train_labels) 
trainloader = dataf.DataLoader(dataset, batch_size=batch_size, shuffle=True)

对应one-target的多分类交叉熵损失函数如下:

def my_loss(outputs, targets):
  
  output2 = outputs - torch.max(outputs, 1, True)[0]
 
 
  P = torch.exp(output2) / torch.sum(torch.exp(output2), 1,True) + 1e-10
 
 
  loss = -torch.mean(targets * torch.log(P))
 
 
  return loss

以上这篇在Pytorch中使用样本权重(sample_weight)的正确方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过urllib2获取带有中文参数url内容的方法
Mar 13 Python
Python获取央视节目单的实现代码
Jul 25 Python
详解Python 2.6 升级至 Python 2.7 的实践心得
Apr 27 Python
解决Scrapy安装错误:Microsoft Visual C++ 14.0 is required...
Oct 01 Python
Appium+python自动化怎么查看程序所占端口号和IP
Jun 14 Python
使用 Python 在京东上抢口罩的思路详解
Feb 27 Python
使用Python Tkinter实现剪刀石头布小游戏功能
Oct 23 Python
python 爬虫基本使用——统计杭电oj题目正确率并排序
Oct 26 Python
python实现在列表中查找某个元素的下标示例
Nov 16 Python
python中使用asyncio实现异步IO实例分析
Feb 26 Python
python使用pygame创建精灵Sprite
Apr 06 Python
python geopandas读取、创建shapefile文件的方法
Jun 29 Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
You might like
我的论坛源代码(五)
2006/10/09 PHP
php Undefined index的问题
2009/06/01 PHP
一个PHP分页类的代码
2011/05/18 PHP
了解PHP的返回引用和局部静态变量
2015/06/04 PHP
curl 出现错误的调试方法(必看)
2017/02/13 PHP
深入解析PHP底层机制及相关原理
2020/12/11 PHP
JQuery自定义事件的应用 JQuery最佳实践
2010/08/01 Javascript
基于jQuery的history历史记录插件
2010/12/11 Javascript
js call方法详细介绍(js 的继承)
2013/11/18 Javascript
2014 年最热门的21款JavaScript框架推荐
2014/12/25 Javascript
jquery移动节点实例
2015/01/14 Javascript
jQuery中$.each使用详解
2015/01/29 Javascript
原生JavaScript制作计算器
2016/10/16 Javascript
angular过滤器实现排序功能
2017/06/27 Javascript
Javascript别踩白块儿(钢琴块儿)小游戏实现代码
2017/07/20 Javascript
vue里面父组件修改子组件样式的方法
2018/02/03 Javascript
ES6与CommonJS中的模块处理的区别
2018/06/13 Javascript
Vue-Quill-Editor富文本编辑器的使用教程
2018/09/21 Javascript
改进 JavaScript 和 Rust 的互操作性并深入认识 wasm-bindgen 组件
2019/07/13 Javascript
Elementui表格组件+sortablejs实现行拖拽排序的示例代码
2019/08/28 Javascript
对Layer弹窗使用及返回数据接收的实例详解
2019/09/26 Javascript
原生js实现随机点名功能
2019/11/05 Javascript
JS实现图片切换特效
2019/12/23 Javascript
微信小程序实现滑动操作代码
2020/04/23 Javascript
python实现类似ftp传输文件的网络程序示例
2014/04/08 Python
python嵌套函数使用外部函数变量的方法(Python2和Python3)
2016/01/31 Python
Python 字符串操作(string替换、删除、截取、复制、连接、比较、查找、包含、大小写转换、分割等)
2018/03/19 Python
pandas通过loc生成新的列方法
2018/11/28 Python
python3学生名片管理v2.0版
2018/11/29 Python
Pytorch中index_select() 函数的实现理解
2019/11/19 Python
深入解析HTML5中的Blob对象的使用
2015/09/08 HTML / CSS
英国受欢迎的运动鞋和街头服装商店:Footasylum
2018/06/12 全球购物
消防安全员岗位职责
2014/03/10 职场文书
业务员辞职信范文
2015/03/02 职场文书
家长通知书家长意见
2015/06/03 职场文书
公司劳动纪律管理制度
2015/08/04 职场文书