pytorch随机采样操作SubsetRandomSampler()


Posted in Python onJuly 07, 2020

这篇文章记录一个采样器都随机地从原始的数据集中抽样数据。抽样数据采用permutation。 生成任意一个下标重排,从而利用下标来提取dataset中的数据的方法

需要的库

import torch

使用方法

这里以MNIST举例

train_dataset = dsets.MNIST(root='./data', #文件存放路径
              train=True,  #提取训练集
              transform=transforms.ToTensor(), #将图像转化为Tensor
              download=True)

sample_size = len(train_dataset)
sampler1 = torch.utils.data.sampler.SubsetRandomSampler(
  np.random.choice(range(len(train_dataset)), sample_size))

代码详解

np.random.choice()

#numpy.random.choice(a, size=None, replace=True, p=None)
#从a(只要是ndarray都可以,但必须是一维的)中随机抽取数字,并组成指定大小(size)的数组
#replace:True表示可以取相同数字,False表示不可以取相同数字
#数组p:与数组a相对应,表示取数组a中每个元素的概率,默认为选取每个元素的概率相同。

那么这里就相当于抽取了一个全排列

torch.utils.data.sampler.SubsetRandomSampler

# 会根据后面给的列表从数据集中按照下标取元素
# class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

所以就可以了。

补充知识:Pytorch学习之torch----随机抽样、序列化、并行化

1. torch.manual_seed(seed)

说明:设置生成随机数的种子,返回一个torch._C.Generator对象。使用随机数种子之后,生成的随机数是相同的。

参数:

seed(int or long) -- 种子

>>> import torch
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> a = torch.rand(2, 3)
>>> a
tensor([[0.7576, 0.2793, 0.4031],
    [0.7347, 0.0293, 0.7999]])
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> b = torch.rand(2, 3)
>>> b
tensor([[0.7576, 0.2793, 0.4031],
    [0.7347, 0.0293, 0.7999]])
>>> a == b
tensor([[1, 1, 1],
    [1, 1, 1]], dtype=torch.uint8)

2. torch.initial_seed()

说明:返回生成随机数的原始种子值

>>> torch.manual_seed(4)
<torch._C.Generator object at 0x0000019684586350>
>>> torch.initial_seed()
4

3. torch.get_rng_state()

说明:返回随机生成器状态(ByteTensor)

>>> torch.initial_seed()
4
>>> torch.get_rng_state()
tensor([4, 0, 0, ..., 0, 0, 0], dtype=torch.uint8)

4. torch.set_rng_state()

说明:设定随机生成器状态

参数:

new_state(ByteTensor) -- 期望的状态

5. torch.default_generator

说明:默认的随机生成器。等于<torch._C.Generator object>

6. torch.bernoulli(input, out=None)

说明:从伯努利分布中抽取二元随机数(0或1)。输入张量包含用于抽取二元值的概率。因此,输入中的所有值都必须在[0,1]区间内。输出张量的第i个元素值,将会以输入张量的第i个概率值等于1。返回值将会是与输入相同大小的张量,每个值为0或者1.

参数:

input(Tensor) -- 输入为伯努利分布的概率值

out(Tensor,可选) -- 输出张量

>>> a = torch.Tensor(3, 3).uniform_(0, 1)
>>> a
tensor([[0.5596, 0.5591, 0.0915],
    [0.2100, 0.0072, 0.0390],
    [0.9929, 0.9131, 0.6186]])
>>> torch.bernoulli(a)
tensor([[0., 1., 0.],
    [0., 0., 0.],
    [1., 1., 1.]])

7. torch.multinomial(input, num_samples, replacement=False, out=None)

说明:返回一个张量,每行包含从input相应行中定义的多项分布中抽取的num_samples个样本。要求输入input每行的值不需要总和为1,但是必须非负且总和不能为0。当抽取样本时,依次从左到右排列(第一个样本对应第一列)。如果输入input是一个向量,输出out也是一个相同长度num_samples的向量。如果输入input是m行的矩阵,输出out是形如m x n的矩阵。并且如果参数replacement为True,则样本抽取可以重复。否则,一个样本在每行不能被重复。

参数:

input(Tensor) -- 包含概率的张量

num_samples(int) -- 抽取的样本数

replacement(bool) -- 布尔值,决定是否能重复抽取

out(Tensor) -- 结果张量

>>> weights = torch.Tensor([0, 10, 3, 0])
>>> weights
tensor([ 0., 10., 3., 0.])
>>> torch.multinomial(weights, 4, replacement=True)
tensor([1, 1, 1, 1])

8. torch.normal(means, std, out=None)

说明:返回一个张量,包含从给定参数means,std的离散正态分布中抽取随机数。均值means是一个张量,包含每个输出元素相关的正态分布的均值。std是一个张量。包含每个输出元素相关的正态分布的标准差。均值和标准差的形状不须匹配,但每个张量的元素个数必须想听。

参数:

means(Tensor) -- 均值

std(Tensor) -- 标准差

out(Tensor) -- 输出张量

>>> n_data = torch.ones(5, 2)
>>> n_data
tensor([[1., 1.],
    [1., 1.],
    [1., 1.],
    [1., 1.],
    [1., 1.]])
>>> x0 = torch.normal(2 * n_data, 1)
>>> x0
tensor([[1.6544, 0.9805],
    [2.1114, 2.7113],
    [1.0646, 1.9675],
    [2.7652, 3.2138],
    [1.1204, 2.0293]])

9. torch.save(obj, f, pickle_module=<module 'pickle' from '/home/lzjs/...)

说明:保存一个对象到一个硬盘文件上。

参数:

obj -- 保存对象

f -- 类文件对象或一个保存文件名的字符串

pickle_module -- 用于pickling源数据和对象的模块

pickle_protocol -- 指定pickle protocal可以覆盖默认参数

10. torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/home/lzjs/...)

说明:从磁盘文件中读取一个通过torch.save()保存的对象。torch.load()可通过参数map_location动态地进行内存重映射,使其能从不动设备中读取文件。一般调用时,需两个参数:storage和location tag。返回不同地址中的storage,或者返回None。如果这个参数是字典的话,意味着从文件的地址标记到当前系统的地址标记的映射。

参数:

f -- l类文件对象或一个保存文件名的字符串

map_location -- 一个函数或字典规定如何remap存储位置

pickle_module -- 用于unpickling元数据和对象的模块

torch.load('tensors.pt')
# 加载所有的张量到CPU
torch.load('tensor.pt', map_location=lambda storage, loc:storage)
# 加载张量到GPU
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

11. torch.get_num_threads()

说明:获得用于并行化CPU操作的OpenMP线程数

12. torch.set_num_threads()

说明:设定用于并行化CPU操作的OpenMP线程数

以上这篇pytorch随机采样操作SubsetRandomSampler()就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现模拟按键,自动翻页看u17漫画
Mar 17 Python
在Mac OS系统上安装Python的Pillow库的教程
Nov 20 Python
Python基于高斯消元法计算线性方程组示例
Jan 17 Python
对python插入数据库和生成插入sql的示例讲解
Nov 14 Python
python获取本机所有IP地址的方法
Dec 26 Python
详解Python3 基本数据类型
Apr 19 Python
Python 中Django验证码功能的实现代码
Jun 20 Python
django框架面向对象ORM模型继承用法实例分析
Jul 29 Python
浅析PyTorch中nn.Module的使用
Aug 18 Python
python操作openpyxl导出Excel 设置单元格格式及合并处理代码实例
Aug 27 Python
python连接PostgreSQL过程解析
Feb 09 Python
如何清空python的变量
Jul 05 Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
keras实现VGG16方式(预测一张图片)
Jul 07 #Python
通过实例解析Python RPC实现原理及方法
Jul 07 #Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 #Python
You might like
php如何连接sql server
2015/10/16 PHP
使用ltrace工具跟踪PHP库函数调用的方法
2016/04/25 PHP
Thinkphp框架 表单自动验证登录注册 ajax自动验证登录注册
2016/12/27 PHP
PHP函数用法详解【初始化、嵌套、内置函数等】
2020/06/02 PHP
Ucren Virtual Desktop V2.0
2006/11/07 Javascript
再谈ie和firefox下的document.all属性
2009/10/21 Javascript
javascript 导出数据到Excel(处理table中的元素)
2009/12/18 Javascript
一个关于javascript匿名函数的问题分析
2012/03/30 Javascript
ExtJS中文乱码之GBK格式编码解决方案及代码
2013/01/20 Javascript
Jquery和JS用外部变量获取Ajax返回的参数值的方法实例(超简单)
2013/06/17 Javascript
JS实现超简洁网页title标题跑动闪烁提示效果代码
2015/10/23 Javascript
JS防止网页被嵌入iframe框架的方法分析
2016/09/13 Javascript
SelectPage v2.4 发布新增纯下拉列表和关闭分页功能
2017/09/07 Javascript
js实现上传并压缩图片效果
2018/01/10 Javascript
vue中的router-view组件的使用教程
2018/10/23 Javascript
vue地址栏直接输入路由无效问题的解决
2018/11/15 Javascript
详解如何在vue项目中使用layui框架及采坑
2019/05/05 Javascript
laypage.js分页插件使用方法详解
2019/07/27 Javascript
微信小程序使用echarts获取数据并生成折线图
2019/10/16 Javascript
图文详解WinPE下安装Python
2016/05/17 Python
利用Python批量压缩png方法实例(支持过滤个别文件与文件夹)
2017/07/30 Python
详解python使用递归、尾递归、循环三种方式实现斐波那契数列
2018/01/16 Python
Window环境下Scrapy开发环境搭建
2018/11/18 Python
python实现弹跳小球
2019/05/13 Python
python清空命令行方式
2020/01/13 Python
python构造IP报文实例
2020/05/05 Python
python和go语言的区别是什么
2020/07/20 Python
HTML5 Canvas中使用用路径描画圆弧
2015/01/01 HTML / CSS
药物学专业学生的自我评价
2013/10/27 职场文书
自荐信格式技巧有哪些呢
2013/11/19 职场文书
工作失误检讨书范文大全
2014/01/13 职场文书
食堂个人先进事迹
2014/01/22 职场文书
个人查摆问题自查报告
2014/10/16 职场文书
高中升旗仪式主持词
2015/07/03 职场文书
施工安全协议书
2016/03/22 职场文书
Django框架中模型的用法
2022/06/10 Python