关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python标准日志模块logging的使用方法
Nov 01 Python
Python基于二分查找实现求整数平方根的方法
May 12 Python
Python中一行和多行import模块问题
Apr 01 Python
利用Python实现在同一网络中的本地文件共享方法
Jun 04 Python
python版本五子棋的实现代码
Dec 11 Python
PyTorch基本数据类型(一)
May 22 Python
Python TCP通信客户端服务端代码实例
Nov 21 Python
python实现异常信息堆栈输出到日志文件
Dec 26 Python
Python实现病毒仿真器的方法示例(附demo)
Feb 19 Python
python读取当前目录下的CSV文件数据
Mar 11 Python
Django实现将views.py中的数据传递到前端html页面,并展示
Mar 16 Python
python 使用Tensorflow训练BP神经网络实现鸢尾花分类
May 12 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
PHP使用redis实现统计缓存mysql压力的方法
2015/11/14 PHP
将PHP程序中返回的JSON格式数据用gzip压缩输出的方法
2016/03/03 PHP
php自动载入类用法实例分析
2016/06/24 PHP
PHP中$GLOBALS['HTTP_RAW_POST_DATA']和$_POST的区别分析
2017/07/03 PHP
ThinkPHP框架实现的微信支付接口开发完整示例
2019/04/10 PHP
解决php用mysql方式连接数据库出现Deprecated报错问题
2019/12/25 PHP
laravel框架路由分组,中间件,命名空间,子域名,路由前缀实例分析
2020/02/18 PHP
TP框架实现上传一张图片和批量上传图片的方法分析
2020/04/23 PHP
RR vs IO BO3 第二场2.13
2021/03/10 DOTA
发布一个高效的JavaScript分析、压缩工具 JavaScript Analyser
2007/11/30 Javascript
js实现全屏漂浮广告移入光标停止移动
2013/12/02 Javascript
Jquery模仿Baidu、Google搜索时自动补充搜索结果提示
2013/12/26 Javascript
jQuery 设置 CSS 属性示例介绍
2014/01/16 Javascript
IE中鼠标经过option触发mouseout的解决方法
2015/01/29 Javascript
JavaScript数组迭代器实例分析
2015/06/09 Javascript
js字符串引用的两种方式(必看)
2016/09/18 Javascript
解决给dom元素绑定click等事件无效问题的方法
2017/02/17 Javascript
javascript实现移动端上传图片功能
2020/08/18 Javascript
js实现简单的点名器随机色实例代码
2020/09/20 Javascript
[02:10]2018DOTA2亚洲邀请赛赛前采访-Liquid
2018/04/03 DOTA
详解python分布式进程
2018/10/08 Python
对python mayavi三维绘图的实现详解
2019/01/08 Python
Python绘制并保存指定大小图像的方法
2019/01/10 Python
DJango的创建和使用详解(默认数据库sqlite3)
2019/11/18 Python
Python:合并两个numpy矩阵的实现
2019/12/02 Python
python爬虫开发之urllib模块详细使用方法与实例全解
2020/03/09 Python
Python3利用openpyxl读写Excel文件的方法实例
2021/02/03 Python
日本著名的服饰鞋帽综合类购物网站:MAGASEEK
2019/01/09 全球购物
应聘编辑自荐信范文
2014/03/12 职场文书
纪检干部现实表现材料
2014/08/21 职场文书
不尊敬老师检讨书范文
2014/11/19 职场文书
2015年监理个人工作总结
2015/05/23 职场文书
二手手机买卖合同范本(2019年版)
2019/10/28 职场文书
简述python四种分词工具,盘点哪个更好用?
2021/04/13 Python
python 如何在list中找Topk的数值和索引
2021/05/20 Python
Pandas加速代码之避免使用for循环
2021/05/30 Python