PyTorch基本数据类型(一)


Posted in Python onMay 22, 2019

PyTorch基础入门一:PyTorch基本数据类型

1)Tensor(张量)

Pytorch里面处理的最基本的操作对象就是Tensor(张量),它表示的其实就是一个多维矩阵,并有矩阵相关的运算操作。在使用上和numpy是对应的,它和numpy唯一的不同就是,pytorch可以在GPU上运行,而numpy不可以。所以,我们也可以使用Tensor来代替numpy的使用。当然,二者也可以相互转换。

Tensor的基本数据类型有五种:

  • 32位浮点型:torch.FloatTensor。pyorch.Tensor()默认的就是这种类型。
  • 64位整型:torch.LongTensor。
  • 32位整型:torch.IntTensor。
  • 16位整型:torch.ShortTensor。
  • 64位浮点型:torch.DoubleTensor。

那么如何定义Tensor张量呢?其实定义的方式和numpy一样,直接传入相应的矩阵即可即可。下面就定义了一个三行两列的矩阵:

import torch
# 导包
 
a = torch.Tensor([[1, 2], [3, 4], [5, 6]])
print(a)

不过在项目之中,更多的做法是以特殊值或者随机值初始化一个矩阵,就像下面这样:

import torch
 
# 定义一个3行2列的全为0的矩阵
b = torch.zeros((3, 2))
 
# 定义一个3行2列的随机值矩阵
c = torch.randn((3, 2))
 
# 定义一个3行2列全为1的矩阵
d = torch.ones((3, 2))
 
print(b)
print(c)
print(d)

Tensor和numpy.ndarray之间还可以相互转换,其方式如下:

  • Numpy转化为Tensor:torch.from_numpy(numpy矩阵)
  • Tensor转化为numpy:Tensor矩阵.numpy()

范例如下:

import torch
import numpy as np
 
# 定义一个3行2列的全为0的矩阵
b = torch.randn((3, 2))
 
# tensor转化为numpy
numpy_b = b.numpy()
print(numpy_b)
 
# numpy转化为tensor
numpy_e = np.array([[1, 2], [3, 4], [5, 6]])
torch_e = torch.from_numpy(numpy_e)
 
print(numpy_e)
print(torch_e)

之前说过,numpy与Tensor最大的区别就是在对GPU的支持上。Tensor只需要调用cuda()函数就可以将其转化为能在GPU上运行的类型。

我们可以通过torch.cuda.is_available()函数来判断当前的环境是否支持GPU,如果支持,则返回True。所以,为保险起见,在项目代码中一般采取“先判断,后使用”的策略来保证代码的正常运行,其基本结构如下:

import torch
 
# 定义一个3行2列的全为0的矩阵
tmp = torch.randn((3, 2))
 
# 如果支持GPU,则定义为GPU类型
if torch.cuda.is_available():
  inputs = tmp.cuda()
# 否则,定义为一般的Tensor类型
else:
  inputs = tmp

2)Variable(变量)

Pytorch里面的Variable类型数据功能更加强大,相当于是在Tensor外层套了一个壳子,这个壳子赋予了前向传播,反向传播,自动求导等功能,在计算图的构建中起的很重要的作用。Variable的结构图如下:

PyTorch基本数据类型(一)

其中最重要的两个属性是:data和grad。Data表示该变量保存的实际数据,通过该属性可以访问到它所保存的原始张量类型,而关于该 variable(变量)的梯度会被累计到.grad 上去。

在使用Variable的时候需要从torch.autograd中导入。下面通过一个例子来看一下它自动求导的过程:

import torch
from torch.autograd import Variable
 
# 定义三个Variable变量
x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
w = Variable(torch.Tensor([2, 3, 4]), requires_grad=True)
b = Variable(torch.Tensor([3, 4, 5]), requires_grad=True)
 
# 构建计算图,公式为:y = w * x^2 + b
y = w * x * x + b
 
# 自动求导,计算梯度
y.backward(torch.Tensor([1, 1, 1]))
 
print(x.grad)
print(w.grad)
print(b.grad)

上述代码的计算图为y = w * x^2 + b。对x, w, b分别求偏导为:x.grad = 2wx,w.grad=x^2,b.grad=1。代值检验可得计算结果是正确的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列表推导式与生成器表达式用法示例
Feb 08 Python
PyQt5每天必学之日历控件QCalendarWidget
Apr 19 Python
python统计多维数组的行数和列数实例
Jun 23 Python
Python基于多线程操作数据库相关问题分析
Jul 11 Python
Python里字典的基本用法(包括嵌套字典)
Feb 27 Python
Python 转换文本编码实现解析
Aug 27 Python
Python3的unicode编码转换成中文的问题及解决方案
Dec 10 Python
查看keras的默认backend实现方式
Jun 19 Python
Python利用Pillow(PIL)库实现验证码图片的全过程
Oct 04 Python
如何用用Python将地址标记在地图上
Feb 07 Python
Python竟然能剪辑视频
May 25 Python
python plt.plot bar 如何设置绘图尺寸大小
Jun 01 Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
You might like
基于mysql的论坛(4)
2006/10/09 PHP
如何获知PHP程序占用多少内存(memory_get_usage)
2012/09/23 PHP
php的ZipArchive类用法实例
2014/10/20 PHP
PHP中Trait及其应用详解
2017/02/14 PHP
Yii2 队列 shmilyzxt/yii2-queue 简单概述
2017/08/02 PHP
js jquery做的图片连续滚动代码
2008/01/06 Javascript
javascript 图片上一张下一张链接效果代码
2010/03/12 Javascript
jquery.fileEveryWhere.js 一个跨浏览器的file显示插件
2011/10/24 Javascript
JavaScript中创建对象和继承示例解读
2014/02/12 Javascript
js的延迟执行问题分析
2014/06/23 Javascript
node.js解决获取图片真实文件类型的问题
2014/12/20 Javascript
JavaScript中用sort()方法对数组元素进行排序的操作
2015/06/09 Javascript
百度地图API之本地搜索与范围搜索
2015/07/30 Javascript
jQuery实现彩带延伸效果的网页加载条loading动画
2015/10/29 Javascript
浅谈js中的引用和复制(传值和传址)
2016/09/18 Javascript
JavaScript-定时器0~9抽奖系统详解(代码)
2017/08/16 Javascript
js使用xml数据载体实现城市省份二级联动效果
2017/11/08 Javascript
解决Angular.js中使用Swiper插件不能滑动的问题
2018/02/26 Javascript
JS中超越现实的匿名函数用法实例分析
2019/06/21 Javascript
微信小程序 授权登录详解(附完整源码)
2019/08/23 Javascript
token 机制和实现方式
2020/12/15 Javascript
python3中bytes和string之间的互相转换
2017/02/09 Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
2020/01/14 Python
sklearn的predict_proba使用说明
2020/06/28 Python
pycharm永久激活超详细教程
2020/10/29 Python
如何基于Python pygame实现动画跑马灯
2020/11/18 Python
python RSA加密的示例
2020/12/09 Python
虚拟环境及venv和virtualenv的区别说明
2021/02/05 Python
德国机车企业:FC-Moto
2017/10/27 全球购物
中学生学习生活的自我评价
2013/10/26 职场文书
2014年迎新年联欢会活动策划方案
2014/02/26 职场文书
希特勒经典演讲稿
2014/05/19 职场文书
医药公司开票员岗位职责
2015/04/15 职场文书
看雷锋电影观后感
2015/06/10 职场文书
企业开发CSS命名BEM代码规范实践
2022/02/12 HTML / CSS
详解MySQL的内连接和外连接
2023/05/08 MySQL