关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现删除当前目录下除当前脚本以外的文件和文件夹实例
Jul 27 Python
apache部署python程序出现503错误的解决方法
Jul 24 Python
Python之自动获取公网IP的实例讲解
Oct 01 Python
Python使用functools实现注解同步方法
Feb 06 Python
Python实现的将文件每一列写入列表功能示例【测试可用】
Mar 19 Python
django+xadmin+djcelery实现后台管理定时任务
Aug 14 Python
Python解决两个整数相除只得到整数部分的实例
Nov 10 Python
Python实现将通信达.day文件读取为DataFrame
Dec 22 Python
Python编程快速上手——选择性拷贝操作案例分析
Feb 28 Python
自定义Django_rest_framework_jwt登陆错误返回的解决
Oct 18 Python
Django配置跨域并开发测试接口
Nov 04 Python
python实现简单猜单词游戏
Dec 24 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
PHP中去掉字符串首尾空格的方法
2012/05/19 PHP
基于PHP生成静态页的实现方法
2013/05/10 PHP
php安装xdebug/php安装pear/phpunit详解步骤(图)
2013/12/22 PHP
php的单例模式及应用场景详解
2021/02/27 PHP
跨浏览器的设置innerHTML方法
2006/09/18 Javascript
滚动经典最新话题[prototype框架]下编写
2006/10/03 Javascript
用示例说明filter()与find()的用法以及children()与find()的区别分析
2013/04/26 Javascript
jQuery中addClass()方法用法实例
2015/01/05 Javascript
jquery表单对象属性过滤选择器实例分析
2015/05/18 Javascript
javascript实现将文件保存到本地方法汇总
2015/07/26 Javascript
jQuery实现鼠标经过弹出提示信息的地图热点效果
2015/08/07 Javascript
jQuery遍历DOM节点操作之filter()方法详解
2016/04/14 Javascript
AngularJS实现的JSONP跨域访问数据传输功能详解
2017/07/20 Javascript
谈谈React中的Render Props模式
2018/12/06 Javascript
详解mpvue实现对苹果X安全区域的适配
2019/07/31 Javascript
Javascript Web Worker使用过程解析
2020/03/16 Javascript
解决VUE mounted 钩子函数执行时 img 未加载导致页面布局的问题
2020/07/27 Javascript
Python中让MySQL查询结果返回字典类型的方法
2014/08/22 Python
简单介绍Python中的struct模块
2015/04/28 Python
python中的for循环
2018/09/28 Python
python提取xml里面的链接源码详解
2019/10/15 Python
Python 爬虫实现增加播客访问量的方法实现
2019/10/31 Python
解决IDEA 的 plugins 搜不到任何的插件问题
2020/05/04 Python
基于python检查矩阵计算结果
2020/05/21 Python
Python如何批量生成和调用变量
2020/11/21 Python
Selenium执行完毕未关闭chromedriver/geckodriver进程的解决办法(java版+python版)
2020/12/07 Python
一款利用纯css3实现的win8加载动画的实例分析
2014/12/11 HTML / CSS
HTML5中原生的右键菜单创建方法
2016/06/28 HTML / CSS
Homestay中文官网:全球寄宿家庭
2018/10/18 全球购物
外国语学院毕业生自荐信
2013/10/28 职场文书
班组长的岗位职责
2013/12/09 职场文书
模具设计与制造专业自荐书
2014/07/01 职场文书
2015年效能监察工作总结
2015/04/23 职场文书
律政俏佳人观后感
2015/06/09 职场文书
python编程实现清理微信重复缓存文件
2021/11/01 Python
vue实现在data里引入相对路径
2022/06/05 Vue.js