pytorch随机采样操作SubsetRandomSampler()


Posted in Python onJuly 07, 2020

这篇文章记录一个采样器都随机地从原始的数据集中抽样数据。抽样数据采用permutation。 生成任意一个下标重排,从而利用下标来提取dataset中的数据的方法

需要的库

import torch

使用方法

这里以MNIST举例

train_dataset = dsets.MNIST(root='./data', #文件存放路径
              train=True,  #提取训练集
              transform=transforms.ToTensor(), #将图像转化为Tensor
              download=True)

sample_size = len(train_dataset)
sampler1 = torch.utils.data.sampler.SubsetRandomSampler(
  np.random.choice(range(len(train_dataset)), sample_size))

代码详解

np.random.choice()

#numpy.random.choice(a, size=None, replace=True, p=None)
#从a(只要是ndarray都可以,但必须是一维的)中随机抽取数字,并组成指定大小(size)的数组
#replace:True表示可以取相同数字,False表示不可以取相同数字
#数组p:与数组a相对应,表示取数组a中每个元素的概率,默认为选取每个元素的概率相同。

那么这里就相当于抽取了一个全排列

torch.utils.data.sampler.SubsetRandomSampler

# 会根据后面给的列表从数据集中按照下标取元素
# class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

所以就可以了。

补充知识:Pytorch学习之torch----随机抽样、序列化、并行化

1. torch.manual_seed(seed)

说明:设置生成随机数的种子,返回一个torch._C.Generator对象。使用随机数种子之后,生成的随机数是相同的。

参数:

seed(int or long) -- 种子

>>> import torch
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> a = torch.rand(2, 3)
>>> a
tensor([[0.7576, 0.2793, 0.4031],
    [0.7347, 0.0293, 0.7999]])
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> b = torch.rand(2, 3)
>>> b
tensor([[0.7576, 0.2793, 0.4031],
    [0.7347, 0.0293, 0.7999]])
>>> a == b
tensor([[1, 1, 1],
    [1, 1, 1]], dtype=torch.uint8)

2. torch.initial_seed()

说明:返回生成随机数的原始种子值

>>> torch.manual_seed(4)
<torch._C.Generator object at 0x0000019684586350>
>>> torch.initial_seed()
4

3. torch.get_rng_state()

说明:返回随机生成器状态(ByteTensor)

>>> torch.initial_seed()
4
>>> torch.get_rng_state()
tensor([4, 0, 0, ..., 0, 0, 0], dtype=torch.uint8)

4. torch.set_rng_state()

说明:设定随机生成器状态

参数:

new_state(ByteTensor) -- 期望的状态

5. torch.default_generator

说明:默认的随机生成器。等于<torch._C.Generator object>

6. torch.bernoulli(input, out=None)

说明:从伯努利分布中抽取二元随机数(0或1)。输入张量包含用于抽取二元值的概率。因此,输入中的所有值都必须在[0,1]区间内。输出张量的第i个元素值,将会以输入张量的第i个概率值等于1。返回值将会是与输入相同大小的张量,每个值为0或者1.

参数:

input(Tensor) -- 输入为伯努利分布的概率值

out(Tensor,可选) -- 输出张量

>>> a = torch.Tensor(3, 3).uniform_(0, 1)
>>> a
tensor([[0.5596, 0.5591, 0.0915],
    [0.2100, 0.0072, 0.0390],
    [0.9929, 0.9131, 0.6186]])
>>> torch.bernoulli(a)
tensor([[0., 1., 0.],
    [0., 0., 0.],
    [1., 1., 1.]])

7. torch.multinomial(input, num_samples, replacement=False, out=None)

说明:返回一个张量,每行包含从input相应行中定义的多项分布中抽取的num_samples个样本。要求输入input每行的值不需要总和为1,但是必须非负且总和不能为0。当抽取样本时,依次从左到右排列(第一个样本对应第一列)。如果输入input是一个向量,输出out也是一个相同长度num_samples的向量。如果输入input是m行的矩阵,输出out是形如m x n的矩阵。并且如果参数replacement为True,则样本抽取可以重复。否则,一个样本在每行不能被重复。

参数:

input(Tensor) -- 包含概率的张量

num_samples(int) -- 抽取的样本数

replacement(bool) -- 布尔值,决定是否能重复抽取

out(Tensor) -- 结果张量

>>> weights = torch.Tensor([0, 10, 3, 0])
>>> weights
tensor([ 0., 10., 3., 0.])
>>> torch.multinomial(weights, 4, replacement=True)
tensor([1, 1, 1, 1])

8. torch.normal(means, std, out=None)

说明:返回一个张量,包含从给定参数means,std的离散正态分布中抽取随机数。均值means是一个张量,包含每个输出元素相关的正态分布的均值。std是一个张量。包含每个输出元素相关的正态分布的标准差。均值和标准差的形状不须匹配,但每个张量的元素个数必须想听。

参数:

means(Tensor) -- 均值

std(Tensor) -- 标准差

out(Tensor) -- 输出张量

>>> n_data = torch.ones(5, 2)
>>> n_data
tensor([[1., 1.],
    [1., 1.],
    [1., 1.],
    [1., 1.],
    [1., 1.]])
>>> x0 = torch.normal(2 * n_data, 1)
>>> x0
tensor([[1.6544, 0.9805],
    [2.1114, 2.7113],
    [1.0646, 1.9675],
    [2.7652, 3.2138],
    [1.1204, 2.0293]])

9. torch.save(obj, f, pickle_module=<module 'pickle' from '/home/lzjs/...)

说明:保存一个对象到一个硬盘文件上。

参数:

obj -- 保存对象

f -- 类文件对象或一个保存文件名的字符串

pickle_module -- 用于pickling源数据和对象的模块

pickle_protocol -- 指定pickle protocal可以覆盖默认参数

10. torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/home/lzjs/...)

说明:从磁盘文件中读取一个通过torch.save()保存的对象。torch.load()可通过参数map_location动态地进行内存重映射,使其能从不动设备中读取文件。一般调用时,需两个参数:storage和location tag。返回不同地址中的storage,或者返回None。如果这个参数是字典的话,意味着从文件的地址标记到当前系统的地址标记的映射。

参数:

f -- l类文件对象或一个保存文件名的字符串

map_location -- 一个函数或字典规定如何remap存储位置

pickle_module -- 用于unpickling元数据和对象的模块

torch.load('tensors.pt')
# 加载所有的张量到CPU
torch.load('tensor.pt', map_location=lambda storage, loc:storage)
# 加载张量到GPU
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

11. torch.get_num_threads()

说明:获得用于并行化CPU操作的OpenMP线程数

12. torch.set_num_threads()

说明:设定用于并行化CPU操作的OpenMP线程数

以上这篇pytorch随机采样操作SubsetRandomSampler()就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在arcgis使用python脚本进行字段计算时是如何解决中文问题的
Oct 18 Python
如何使用python爬取csdn博客访问量
Feb 14 Python
利用python微信库itchat实现微信自动回复功能
May 18 Python
python机器学习实战之树回归详解
Dec 20 Python
python pandas 如何替换某列的一个值
Jun 09 Python
python requests更换代理适用于IP频率限制的方法
Aug 21 Python
浅析python表达式4+0.5值的数据类型
Feb 26 Python
Python中remove漏删和索引越界问题的解决
Mar 18 Python
python实现3D地图可视化
Mar 25 Python
Python基于jieba, wordcloud库生成中文词云
May 13 Python
Python SMTP配置参数并发送邮件
Jun 16 Python
python在package下继续嵌套一个package
Apr 14 Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
keras实现VGG16方式(预测一张图片)
Jul 07 #Python
通过实例解析Python RPC实现原理及方法
Jul 07 #Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 #Python
You might like
wamp下修改mysql访问密码的解决方法
2013/05/07 PHP
php处理单文件、多文件上传代码分享
2016/08/24 PHP
js form action动态修改方法
2008/11/04 Javascript
JavaScript 以对象为索引的关联数组
2010/05/19 Javascript
Javascript实现滑块滑动改变值的实现代码
2013/04/12 Javascript
Javascript实现重力弹跳拖拽运动效果示例
2013/06/28 Javascript
jQuery的DOM操作之删除节点示例
2014/01/03 Javascript
jQuery中last()方法用法实例
2015/01/06 Javascript
JavaScript中实现依赖注入的思路分享
2015/01/15 Javascript
DOM操作一些常用的属性汇总
2015/03/13 Javascript
js图片翻书效果代码分享
2015/08/20 Javascript
基于canvas实现的绚丽圆圈效果完整实例
2016/01/26 Javascript
jquery实现简单的banner轮播效果【实例】
2016/03/30 Javascript
JavaScript知识点总结(十)之this关键字
2016/05/31 Javascript
JavaScript学习笔记整理_简单实现枚举类型,扑克牌应用
2016/09/19 Javascript
bootstrap配合Masonry插件实现瀑布式布局
2017/01/18 Javascript
BootStrap中jQuery插件Carousel实现轮播广告效果
2017/03/27 jQuery
hammer.js实现图片手势放大效果
2017/08/29 Javascript
jquery实现回车键触发事件(实例讲解)
2017/11/21 jQuery
AngularJS实现的锚点楼层跳转功能示例
2018/01/02 Javascript
详解使用 Node.js 开发简单的脚手架工具
2018/06/08 Javascript
浅谈Vue SSR中的Bundle的具有使用
2019/11/21 Javascript
微信小程序保存图片到相册权限设置
2020/04/09 Javascript
探索node之事件循环的实现
2020/10/30 Javascript
[46:03]LGD vs VGJ.T 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python实现机器学习之多元线性回归
2018/09/06 Python
PyCharm更改字体和界面样式的方法步骤
2019/09/27 Python
如何查看浏览器对html5的支持情况
2020/12/15 HTML / CSS
全球地下的服装和态度:Slam Jam
2018/02/04 全球购物
中秋节礼品促销方案
2014/02/02 职场文书
红旗渠导游词
2015/02/09 职场文书
中学生国庆节演讲稿2015
2015/07/30 职场文书
五星级酒店宣传口号
2015/12/25 职场文书
幼师自荐信范文(2016推荐篇)
2016/01/28 职场文书
源码解读Spring-Integration执行过程
2021/06/11 Java/Android
ICOM R71E和R72E图文对比解说
2022/04/07 无线电