关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现k均值算法示例(k均值聚类算法)
Mar 16 Python
举例讲解Python中的身份运算符的使用方法
Oct 13 Python
Python中元组,列表,字典的区别
May 21 Python
python psutil库安装教程
Mar 19 Python
利用python和ffmpeg 批量将其他图片转换为.yuv格式的方法
Jan 08 Python
python3编写ThinkPHP命令执行Getshell的方法
Feb 26 Python
Pycharm如何打断点的方法步骤
Jun 13 Python
PyQt5 QListWidget选择多项并返回的实例
Jun 17 Python
在windows下使用python进行串口通讯的方法
Jul 02 Python
Python使用Opencv实现图像特征检测与匹配的方法
Oct 30 Python
Python基于百度AI实现OCR文字识别
Apr 02 Python
如何在windows下安装配置python工具Ulipad
Oct 27 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
PHP4中session登录页面的应用
2008/07/25 PHP
PHP的autoload机制的实现解析
2012/09/15 PHP
PHP命令行执行整合pathinfo模拟定时任务实例
2016/08/12 PHP
Yii2 队列 shmilyzxt/yii2-queue 简单概述
2017/08/02 PHP
如何用javascript判断录入的日期是否合法
2007/01/08 Javascript
用JS实现一个页面多个css样式实现
2008/05/29 Javascript
Javascript select控件操作大全(新增、修改、删除、选中、清空、判断存在等)
2008/12/19 Javascript
jQuery EasyUI API 中文文档 - Panel面板
2011/09/30 Javascript
jquery中文乱码的多种解决方法
2013/06/21 Javascript
js实现同一个页面多个渐变效果的方法
2015/04/10 Javascript
vue-resource 拦截器(interceptor)的使用详解
2017/07/04 Javascript
关于jquery form表单序列化的注意事项详解
2017/08/01 jQuery
Vue的轮播图组件实现方法
2018/03/03 Javascript
js实现简单选项卡功能
2020/03/23 Javascript
详解vue中async-await的使用误区
2018/12/05 Javascript
详解VSCode配置启动Vue项目
2019/05/14 Javascript
微信小程序实现一张或多张图片上传(云开发)
2019/09/25 Javascript
vue-cli+iview项目打包上线之后图标不显示问题及解决方法
2019/10/16 Javascript
JavaScript Date对象功能与用法学习记录
2020/04/28 Javascript
原生js实现购物车
2020/09/23 Javascript
[01:13]2014DOTA2西雅图邀请赛 舌尖上的TI4
2014/07/08 DOTA
python通过SSH登陆linux并操作的实现
2019/10/10 Python
Python语法垃圾回收机制原理解析
2020/03/25 Python
python如何停止递归
2020/09/09 Python
CSS3 倾斜的网页图片库实例教程
2009/11/14 HTML / CSS
HTML5+CSS3应用详解
2014/02/24 HTML / CSS
兰蔻英国官网:Lancome英国
2019/04/30 全球购物
业务经理岗位职责
2013/11/11 职场文书
《金孔雀轻轻跳》教学反思
2014/04/20 职场文书
五年级上册复习计划
2015/01/19 职场文书
语文复习计划
2015/01/19 职场文书
2015年村计划生育工作总结
2015/04/28 职场文书
2015迎新晚会开场白
2015/07/17 职场文书
创业计划书之农家乐
2019/10/09 职场文书
Python3 类型标注支持操作
2021/06/02 Python
Django drf请求模块源码解析
2021/06/08 Python