关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用三角迭代计算圆周率PI的方法
Mar 20 Python
python读写ini配置文件方法实例分析
Jun 30 Python
基于python 二维数组及画图的实例详解
Apr 03 Python
python topN 取最大的N个数或最小的N个数方法
Jun 04 Python
Python实现的字典排序操作示例【按键名key与键值value排序】
Dec 21 Python
python 计算平均平方误差(MSE)的实例
Jun 29 Python
Python求均值,方差,标准差的实例
Jun 29 Python
python3安装crypto出错及解决方法
Jul 30 Python
利用python-docx模块写批量生日邀请函
Aug 26 Python
TensorBoard 计算图的查看方式
Feb 15 Python
Python模拟登入的N种方式(建议收藏)
May 31 Python
使用python实现时间序列白噪声检验方式
Jun 03 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
PHP中的多行字符串传递给JavaScript的两种方法
2014/06/19 PHP
windows7下php开发环境搭建图文教程
2015/01/06 PHP
开启PHP Static 关键字之旅模式
2015/11/13 PHP
js 文本滚动效果的实例代码
2013/08/17 Javascript
深入理解javascript动态插入技术
2013/11/12 Javascript
jQuery级联操作绑定事件实例
2014/09/02 Javascript
js实现用户注册协议倒计时的方法
2015/01/21 Javascript
JavaScript中document.forms[0]与getElementByName区别
2015/01/21 Javascript
Javascript核心读书有感之表达式和运算符
2015/02/11 Javascript
window.location.reload 刷新使用分析(去对话框)
2015/11/11 Javascript
JQueryEasyUI之DataGrid数据显示
2016/11/23 Javascript
js获取地址栏中传递的参数(两种方法)
2017/02/08 Javascript
jQuery使用EasyUi实现三级联动下拉框效果
2017/03/08 Javascript
vue使用监听实现全选反选功能
2018/07/06 Javascript
用Cordova打包Vue项目的方法步骤
2019/02/02 Javascript
JS实现transform实现扇子效果
2020/01/17 Javascript
python使用nntp读取新闻组内容的方法
2015/05/08 Python
浅谈python新手中常见的疑惑及解答
2016/06/14 Python
利用Python查看目录中的文件示例详解
2017/08/28 Python
浅谈python jieba分词模块的基本用法
2017/11/09 Python
Python面向对象类的继承实例详解
2018/06/27 Python
Python使用sorted对字典的key或value排序
2018/11/15 Python
Face++ API实现手势识别系统设计
2018/11/21 Python
Python Django模板之模板过滤器与自定义模板过滤器示例
2019/10/18 Python
Pytorch实现的手写数字mnist识别功能完整示例
2019/12/13 Python
解决pycharm不能自动补全第三方库的函数和属性问题
2020/03/12 Python
护理自荐信
2013/10/22 职场文书
新郎婚宴答谢词
2014/01/19 职场文书
2014庆六一活动方案
2014/03/02 职场文书
终止合同协议书
2014/04/17 职场文书
食品安全工作方案
2014/05/07 职场文书
对外汉语专业大学生职业生涯规划范文
2014/09/13 职场文书
大学新生军训自我鉴定
2014/09/18 职场文书
表扬信范文
2015/05/04 职场文书
《乘法分配律》教学反思
2016/02/24 职场文书
为什么mysql字段要使用NOT NULL
2021/05/13 MySQL