pytorch随机采样操作SubsetRandomSampler()


Posted in Python onJuly 07, 2020

这篇文章记录一个采样器都随机地从原始的数据集中抽样数据。抽样数据采用permutation。 生成任意一个下标重排,从而利用下标来提取dataset中的数据的方法

需要的库

import torch

使用方法

这里以MNIST举例

train_dataset = dsets.MNIST(root='./data', #文件存放路径
              train=True,  #提取训练集
              transform=transforms.ToTensor(), #将图像转化为Tensor
              download=True)

sample_size = len(train_dataset)
sampler1 = torch.utils.data.sampler.SubsetRandomSampler(
  np.random.choice(range(len(train_dataset)), sample_size))

代码详解

np.random.choice()

#numpy.random.choice(a, size=None, replace=True, p=None)
#从a(只要是ndarray都可以,但必须是一维的)中随机抽取数字,并组成指定大小(size)的数组
#replace:True表示可以取相同数字,False表示不可以取相同数字
#数组p:与数组a相对应,表示取数组a中每个元素的概率,默认为选取每个元素的概率相同。

那么这里就相当于抽取了一个全排列

torch.utils.data.sampler.SubsetRandomSampler

# 会根据后面给的列表从数据集中按照下标取元素
# class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。

所以就可以了。

补充知识:Pytorch学习之torch----随机抽样、序列化、并行化

1. torch.manual_seed(seed)

说明:设置生成随机数的种子,返回一个torch._C.Generator对象。使用随机数种子之后,生成的随机数是相同的。

参数:

seed(int or long) -- 种子

>>> import torch
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> a = torch.rand(2, 3)
>>> a
tensor([[0.7576, 0.2793, 0.4031],
    [0.7347, 0.0293, 0.7999]])
>>> torch.manual_seed(1)
<torch._C.Generator object at 0x0000019684586350>
>>> b = torch.rand(2, 3)
>>> b
tensor([[0.7576, 0.2793, 0.4031],
    [0.7347, 0.0293, 0.7999]])
>>> a == b
tensor([[1, 1, 1],
    [1, 1, 1]], dtype=torch.uint8)

2. torch.initial_seed()

说明:返回生成随机数的原始种子值

>>> torch.manual_seed(4)
<torch._C.Generator object at 0x0000019684586350>
>>> torch.initial_seed()
4

3. torch.get_rng_state()

说明:返回随机生成器状态(ByteTensor)

>>> torch.initial_seed()
4
>>> torch.get_rng_state()
tensor([4, 0, 0, ..., 0, 0, 0], dtype=torch.uint8)

4. torch.set_rng_state()

说明:设定随机生成器状态

参数:

new_state(ByteTensor) -- 期望的状态

5. torch.default_generator

说明:默认的随机生成器。等于<torch._C.Generator object>

6. torch.bernoulli(input, out=None)

说明:从伯努利分布中抽取二元随机数(0或1)。输入张量包含用于抽取二元值的概率。因此,输入中的所有值都必须在[0,1]区间内。输出张量的第i个元素值,将会以输入张量的第i个概率值等于1。返回值将会是与输入相同大小的张量,每个值为0或者1.

参数:

input(Tensor) -- 输入为伯努利分布的概率值

out(Tensor,可选) -- 输出张量

>>> a = torch.Tensor(3, 3).uniform_(0, 1)
>>> a
tensor([[0.5596, 0.5591, 0.0915],
    [0.2100, 0.0072, 0.0390],
    [0.9929, 0.9131, 0.6186]])
>>> torch.bernoulli(a)
tensor([[0., 1., 0.],
    [0., 0., 0.],
    [1., 1., 1.]])

7. torch.multinomial(input, num_samples, replacement=False, out=None)

说明:返回一个张量,每行包含从input相应行中定义的多项分布中抽取的num_samples个样本。要求输入input每行的值不需要总和为1,但是必须非负且总和不能为0。当抽取样本时,依次从左到右排列(第一个样本对应第一列)。如果输入input是一个向量,输出out也是一个相同长度num_samples的向量。如果输入input是m行的矩阵,输出out是形如m x n的矩阵。并且如果参数replacement为True,则样本抽取可以重复。否则,一个样本在每行不能被重复。

参数:

input(Tensor) -- 包含概率的张量

num_samples(int) -- 抽取的样本数

replacement(bool) -- 布尔值,决定是否能重复抽取

out(Tensor) -- 结果张量

>>> weights = torch.Tensor([0, 10, 3, 0])
>>> weights
tensor([ 0., 10., 3., 0.])
>>> torch.multinomial(weights, 4, replacement=True)
tensor([1, 1, 1, 1])

8. torch.normal(means, std, out=None)

说明:返回一个张量,包含从给定参数means,std的离散正态分布中抽取随机数。均值means是一个张量,包含每个输出元素相关的正态分布的均值。std是一个张量。包含每个输出元素相关的正态分布的标准差。均值和标准差的形状不须匹配,但每个张量的元素个数必须想听。

参数:

means(Tensor) -- 均值

std(Tensor) -- 标准差

out(Tensor) -- 输出张量

>>> n_data = torch.ones(5, 2)
>>> n_data
tensor([[1., 1.],
    [1., 1.],
    [1., 1.],
    [1., 1.],
    [1., 1.]])
>>> x0 = torch.normal(2 * n_data, 1)
>>> x0
tensor([[1.6544, 0.9805],
    [2.1114, 2.7113],
    [1.0646, 1.9675],
    [2.7652, 3.2138],
    [1.1204, 2.0293]])

9. torch.save(obj, f, pickle_module=<module 'pickle' from '/home/lzjs/...)

说明:保存一个对象到一个硬盘文件上。

参数:

obj -- 保存对象

f -- 类文件对象或一个保存文件名的字符串

pickle_module -- 用于pickling源数据和对象的模块

pickle_protocol -- 指定pickle protocal可以覆盖默认参数

10. torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/home/lzjs/...)

说明:从磁盘文件中读取一个通过torch.save()保存的对象。torch.load()可通过参数map_location动态地进行内存重映射,使其能从不动设备中读取文件。一般调用时,需两个参数:storage和location tag。返回不同地址中的storage,或者返回None。如果这个参数是字典的话,意味着从文件的地址标记到当前系统的地址标记的映射。

参数:

f -- l类文件对象或一个保存文件名的字符串

map_location -- 一个函数或字典规定如何remap存储位置

pickle_module -- 用于unpickling元数据和对象的模块

torch.load('tensors.pt')
# 加载所有的张量到CPU
torch.load('tensor.pt', map_location=lambda storage, loc:storage)
# 加载张量到GPU
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

11. torch.get_num_threads()

说明:获得用于并行化CPU操作的OpenMP线程数

12. torch.set_num_threads()

说明:设定用于并行化CPU操作的OpenMP线程数

以上这篇pytorch随机采样操作SubsetRandomSampler()就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
centos下更新Python版本的步骤
Feb 12 Python
简单谈谈python中的多进程
Nov 06 Python
python 爬虫出现403禁止访问错误详解
Mar 11 Python
Python实现对文件进行单词划分并去重排序操作示例
Jul 10 Python
tensorflow -gpu安装方法(不用自己装cuda,cdnn)
Jan 20 Python
基于python求两个列表的并集.交集.差集
Feb 10 Python
Pycharm 安装 idea VIM插件的图文教程详解
Feb 21 Python
jupyter note 实现将数据保存为word
Apr 14 Python
python 操作mysql数据中fetchone()和fetchall()方式
May 15 Python
使用Python实现将多表分批次从数据库导出到Excel
May 15 Python
Ubuntu配置Pytorch on Graph (PoG)环境过程图解
Nov 19 Python
Python 多线程处理任务实例
Nov 07 Python
pytorch加载自己的图像数据集实例
Jul 07 #Python
keras实现VGG16 CIFAR10数据集方式
Jul 07 #Python
使用darknet框架的imagenet数据分类预训练操作
Jul 07 #Python
Python调用C语言程序方法解析
Jul 07 #Python
keras实现VGG16方式(预测一张图片)
Jul 07 #Python
通过实例解析Python RPC实现原理及方法
Jul 07 #Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 #Python
You might like
PHP从FLV文件获取视频预览图的方法
2015/03/12 PHP
jquery 插件 web2.0分格的分页脚本,可用于ajax无刷新分页
2008/12/25 Javascript
基于jquery的让页面控件不可用的实现代码
2010/04/27 Javascript
常用一些Javascript判断函数
2012/08/14 Javascript
JQuery中$.ajax()方法参数详解及应用
2013/12/12 Javascript
如何将php数组或者对象传递给javascript
2014/03/20 Javascript
Javascript 赋值机制详解
2014/11/23 Javascript
jQuery EasyUi实战教程之布局篇
2016/01/26 Javascript
jQuery拖拽排序插件制作拖拽排序效果(附源码下载)
2016/02/23 Javascript
通过js修改input、select默认字体颜色
2017/04/19 Javascript
python爬取安居客二手房网站数据(实例讲解)
2017/10/19 Javascript
基于vue-video-player自定义播放器的方法
2018/03/21 Javascript
JavaScript设计模式之责任链模式实例分析
2019/01/16 Javascript
解决前后端分离 vue+springboot 跨域 session+cookie失效问题
2019/05/13 Javascript
vue-router的两种模式的区别
2019/05/30 Javascript
微信小程序 行的删除和增加操作实现详解
2019/09/29 Javascript
vue项目中监听手机物理返回键的实现
2020/01/18 Javascript
Vue实现指令式动态追加小球动画组件的步骤
2020/12/18 Vue.js
three.js中多线程的使用及性能测试详解
2021/01/07 Javascript
[11:01]2014DOTA2西雅图邀请赛 冷冷带你探秘威斯汀
2014/07/08 DOTA
[08:56]DOTA2-DPC中国联赛2月23日Recap集锦
2021/03/11 DOTA
Python中的super用法详解
2015/05/28 Python
Python设计模式编程中解释器模式的简单程序示例分享
2016/03/02 Python
Python3如何解决字符编码问题详解
2017/04/23 Python
Python 获取当前所在目录的方法详解
2017/08/02 Python
python使用xlrd模块读取xlsx文件中的ip方法
2019/01/11 Python
python实现通过队列完成进程间的多任务功能示例
2019/10/28 Python
详解Tensorflow不同版本要求与CUDA及CUDNN版本对应关系
2020/08/04 Python
使用CSS3编写类似iOS中的复选框及带开关的按钮
2016/04/11 HTML / CSS
来自全球大都市的高级街头服饰:Pegador
2018/01/03 全球购物
人事助理岗位职责
2013/11/18 职场文书
简单租房协议书
2014/04/09 职场文书
优秀求职信
2014/05/29 职场文书
我的中国梦演讲稿500字
2014/08/19 职场文书
2014年个人工作总结范文
2014/11/07 职场文书
环保建议书作文300字
2015/09/14 职场文书