关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现多线程采集的2个代码例子
Jul 07 Python
Python下的subprocess模块的入门指引
Apr 16 Python
Python实现在matplotlib中两个坐标轴之间画一条直线光标的方法
May 20 Python
Python编程中的for循环语句学习教程
Oct 14 Python
Python使用urllib2模块抓取HTML页面资源的实例分享
May 03 Python
Python中字典的浅拷贝与深拷贝用法实例分析
Jan 02 Python
pip安装py_zipkin时提示的SSL问题对应
Dec 29 Python
Python语法分析之字符串格式化
Jun 13 Python
python实现图像外边界跟踪操作
Jul 13 Python
python3.7调试的实例方法
Jul 21 Python
Python pip install之SSL异常处理操作
Sep 03 Python
秀!学妹看见都惊呆的Python小招数!【详细语言特性使用技巧】
Apr 27 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
PHP 伪静态隐藏传递参数名的四种方法
2010/02/22 PHP
微信公众号支付之坑:调用支付jsapi缺少参数 timeStamp等错误解决方法
2016/01/12 PHP
PHP使用文件锁解决高并发问题示例
2018/03/29 PHP
js实现页面打印功能实例代码(附去页眉页脚功能代码)
2009/12/15 Javascript
javascript开发技术大全 第4章 直接量与字符集
2011/07/03 Javascript
JavaScript使用HTML5的window.postMessage实现跨域通信例子
2014/04/11 Javascript
js检验密码强度(低中高)附图
2014/06/05 Javascript
jquery+html5制作超酷的圆盘时钟表
2015/04/14 Javascript
js获取鼠标位置实例详解
2015/12/09 Javascript
微信小程序 wx.request(object) API详解及实例代码
2016/09/30 Javascript
微信小程序使用image组件显示图片的方法【附源码下载】
2017/12/08 Javascript
VueJs监听window.resize方法示例
2018/01/17 Javascript
vue cli3.0 引入eslint 结合vscode使用
2019/05/27 Javascript
vue实现登录页面的验证码以及验证过程解析(面向新手)
2019/08/02 Javascript
js实现数据导出为EXCEL(支持大量数据导出)
2020/03/31 Javascript
解决vue的touchStart事件及click事件冲突问题
2020/07/21 Javascript
vue深度监听(监听对象和数组的改变)与立即执行监听实例
2020/09/04 Javascript
原生js实现下拉框选择组件
2021/01/20 Javascript
python简单实现刷新智联简历
2016/03/30 Python
Python实现快速排序算法及去重的快速排序的简单示例
2016/06/26 Python
Python基于递归算法实现的走迷宫问题
2017/08/04 Python
selenium+python实现1688网站验证码图片的截取功能
2018/08/14 Python
Python Cookie 读取和保存方法
2018/12/28 Python
Django CBV与FBV原理及实例详解
2019/08/12 Python
python中用ctypes模拟点击的实例讲解
2020/11/26 Python
印尼综合在线预订网站:Tiket.com(机票、酒店、火车、租车和娱乐)
2018/10/11 全球购物
Timberland德国官网:靴子、鞋子、衣服、夹克及配件
2019/12/10 全球购物
投标邀请书范文
2014/01/31 职场文书
财产公证书样本
2014/04/04 职场文书
团代会宣传工作方案
2014/05/08 职场文书
加油口号大全
2014/06/13 职场文书
2014年社区党建工作汇报材料
2014/11/02 职场文书
违反学校规则制度检讨书
2015/01/01 职场文书
看上去很美观后感
2015/06/10 职场文书
pycharm无法导入lxml的解决办法
2021/03/31 Python
如何用JavaScript学习算法复杂度
2021/04/30 Javascript