关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
2款Python内存检测工具介绍和使用方法
Jun 01 Python
用python删除java文件头上版权信息的方法
Jul 31 Python
python中随机函数random用法实例
Apr 30 Python
Python实现视频下载功能
Mar 14 Python
python3 pillow生成简单验证码图片的示例
Sep 19 Python
解决python3爬虫无法显示中文的问题
Apr 12 Python
Python Matplotlib实现三维数据的散点图绘制
Mar 19 Python
关于Flask项目无法使用公网IP访问的解决方式
Nov 19 Python
python标识符命名规范原理解析
Jan 10 Python
Pytorch GPU显存充足却显示out of memory的解决方式
Jan 13 Python
Python 简单计算要求形状面积的实例
Jan 18 Python
Python 连接 MySQL 的几种方法
Sep 09 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
从手册去理解分析PHP session机制
2011/07/17 PHP
写出高质量的PHP程序
2012/02/04 PHP
PHP读取PDF内容配合Xpdf的使用
2012/11/24 PHP
PHP用FTP类上传文件视频等的简单实现方法
2016/09/23 PHP
php一个文件搞定微信jssdk配置
2016/12/12 PHP
PHP操作MongoDB实现增删改查功能【附php7操作MongoDB方法】
2018/04/24 PHP
一个简单的jQuery计算器实现了连续计算功能
2014/07/21 Javascript
Jquery中offset()和position()的区别分析
2015/02/05 Javascript
node.js require() 源码解读
2015/12/13 Javascript
第一次接触JS require.js模块化工具
2016/04/17 Javascript
微信小程序 地图定位简单实例
2016/10/14 Javascript
Bootstrap基本插件学习笔记之按钮(21)
2016/12/08 Javascript
JavaScript自动点击链接 防止绕过浏览器访问的方法
2017/01/19 Javascript
jQuery.validate.js表单验证插件的使用代码详解
2018/10/22 jQuery
JQuery Ajax跨域调用和非跨域调用问题实例分析
2019/04/16 jQuery
[02:20]DOTA2英雄基础教程 黑暗贤者
2013/12/19 DOTA
[15:39]教你分分钟做大人:龙骑士
2014/10/30 DOTA
[57:24]LGD vs VGJ.T 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
简单谈谈python中的Queue与多进程
2016/08/25 Python
对python使用http、https代理的实例讲解
2018/05/07 Python
Flask框架模板继承实现方法分析
2019/07/31 Python
python3.6、opencv安装环境搭建过程(图文教程)
2019/11/05 Python
Windows下实现将Pascal VOC转化为TFRecords
2020/02/17 Python
Python Selenium实现无可视化界面过程解析
2020/08/25 Python
基于pycharm 项目和项目文件命名规则的介绍
2021/01/15 Python
SAZAC的动物连体衣和动物睡衣:Kigurumi Shop
2020/03/14 全球购物
linux面试题参考答案(2)
2015/12/06 面试题
Ruby中的保护方法和私有方法与一般面向对象程序设计语言的一样吗
2013/05/01 面试题
求职者应聘的自我评价
2013/10/16 职场文书
学习雷锋倡议书
2014/04/15 职场文书
县长群众路线对照检查材料思想汇报
2014/10/02 职场文书
2014年度个人总结范文
2015/03/09 职场文书
公司副总经理岗位职责
2015/04/08 职场文书
干部考核工作总结
2015/08/12 职场文书
2015元旦感言
2015/12/09 职场文书
Spring Data JPA使用JPQL与原生SQL进行查询的操作
2021/06/15 Java/Android