关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细介绍Ruby中的正则表达式
Apr 10 Python
Python编程判断一个正整数是否为素数的方法
Apr 14 Python
python 环境变量和import模块导入方法(详解)
Jul 11 Python
Python实现检测文件MD5值的方法示例
Apr 11 Python
Python中XlsxWriter模块简介与用法分析
Apr 24 Python
python取数作为临时极大值(极小值)的方法
Oct 15 Python
pymysql模块的使用(增删改查)详解
Sep 09 Python
pytorch中torch.max和Tensor.view函数用法详解
Jan 03 Python
Python实现给PDF添加水印的方法
Jan 25 Python
利用Python批量识别电子账单数据的方法
Feb 08 Python
Python3中PyQt5简单实现文件打开及保存
Jun 10 Python
Python 装饰器(decorator)常用的创建方式及解析
Apr 24 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
PHP 5.0 Pear安装方法
2006/12/06 PHP
给WordPress中的留言加上楼层号的PHP代码实例
2015/12/14 PHP
php获取网站根目录物理路径的几种方法(推荐)
2017/03/04 PHP
Laravel 使用查询构造器配合原生sql语句查询的例子
2019/10/12 PHP
ie和firefox不兼容的解决方法集合
2009/04/28 Javascript
Checbox的操作含已选、未选及判断代码
2013/11/07 Javascript
动态加载脚本提升javascript性能
2014/02/24 Javascript
两种方法实现在HTML页面加载完毕后运行某个js
2014/06/16 Javascript
nodejs npm package.json中文文档
2014/09/04 NodeJs
js实现创建删除html元素小结
2015/09/30 Javascript
微信小程序中用WebStorm使用LESS
2017/03/08 Javascript
Javascript之图片的延迟加载的实例详解
2017/07/24 Javascript
详解在React.js中使用PureComponent的重要性和使用方式
2018/07/10 Javascript
对layer弹出框中icon数字参数的说明介绍
2019/09/04 Javascript
Vue列表如何实现滚动到指定位置样式改变效果
2020/05/09 Javascript
python正则表达式修复网站文章字体不统一的解决方法
2013/02/21 Python
python实现Floyd算法
2018/01/03 Python
python爬虫 使用真实浏览器打开网页的两种方法总结
2018/04/21 Python
python创建文件备份的脚本
2018/09/11 Python
一篇文章搞定Python操作文件与目录
2019/08/13 Python
浅谈Python描述数据结构之KMP篇
2020/09/06 Python
VSCode中autopep8无法运行问题解决方案(提示Error: Command failed,usage)
2021/03/02 Python
使用CSS3来绘制一个月食图案
2015/07/18 HTML / CSS
菲律宾票务网站:StubHub菲律宾
2018/04/21 全球购物
个人实用简单的自我评价
2013/10/19 职场文书
企业为何需要商业计划书
2013/12/26 职场文书
档案接收函范文
2014/01/10 职场文书
财务部经理岗位职责
2014/02/03 职场文书
交通文明倡议书
2014/05/16 职场文书
理想演讲稿范文
2014/05/21 职场文书
如何写股份合作协议书
2014/09/11 职场文书
2015年元旦演讲稿
2014/09/12 职场文书
股东出资证明书(正规版)
2014/09/24 职场文书
捐资助学感谢信
2015/01/21 职场文书
无房证明样本
2015/06/17 职场文书
2016年“我们的节日·清明节”活动总结
2016/04/01 职场文书