PyTorch基本数据类型(一)


Posted in Python onMay 22, 2019

PyTorch基础入门一:PyTorch基本数据类型

1)Tensor(张量)

Pytorch里面处理的最基本的操作对象就是Tensor(张量),它表示的其实就是一个多维矩阵,并有矩阵相关的运算操作。在使用上和numpy是对应的,它和numpy唯一的不同就是,pytorch可以在GPU上运行,而numpy不可以。所以,我们也可以使用Tensor来代替numpy的使用。当然,二者也可以相互转换。

Tensor的基本数据类型有五种:

  • 32位浮点型:torch.FloatTensor。pyorch.Tensor()默认的就是这种类型。
  • 64位整型:torch.LongTensor。
  • 32位整型:torch.IntTensor。
  • 16位整型:torch.ShortTensor。
  • 64位浮点型:torch.DoubleTensor。

那么如何定义Tensor张量呢?其实定义的方式和numpy一样,直接传入相应的矩阵即可即可。下面就定义了一个三行两列的矩阵:

import torch
# 导包
 
a = torch.Tensor([[1, 2], [3, 4], [5, 6]])
print(a)

不过在项目之中,更多的做法是以特殊值或者随机值初始化一个矩阵,就像下面这样:

import torch
 
# 定义一个3行2列的全为0的矩阵
b = torch.zeros((3, 2))
 
# 定义一个3行2列的随机值矩阵
c = torch.randn((3, 2))
 
# 定义一个3行2列全为1的矩阵
d = torch.ones((3, 2))
 
print(b)
print(c)
print(d)

Tensor和numpy.ndarray之间还可以相互转换,其方式如下:

  • Numpy转化为Tensor:torch.from_numpy(numpy矩阵)
  • Tensor转化为numpy:Tensor矩阵.numpy()

范例如下:

import torch
import numpy as np
 
# 定义一个3行2列的全为0的矩阵
b = torch.randn((3, 2))
 
# tensor转化为numpy
numpy_b = b.numpy()
print(numpy_b)
 
# numpy转化为tensor
numpy_e = np.array([[1, 2], [3, 4], [5, 6]])
torch_e = torch.from_numpy(numpy_e)
 
print(numpy_e)
print(torch_e)

之前说过,numpy与Tensor最大的区别就是在对GPU的支持上。Tensor只需要调用cuda()函数就可以将其转化为能在GPU上运行的类型。

我们可以通过torch.cuda.is_available()函数来判断当前的环境是否支持GPU,如果支持,则返回True。所以,为保险起见,在项目代码中一般采取“先判断,后使用”的策略来保证代码的正常运行,其基本结构如下:

import torch
 
# 定义一个3行2列的全为0的矩阵
tmp = torch.randn((3, 2))
 
# 如果支持GPU,则定义为GPU类型
if torch.cuda.is_available():
  inputs = tmp.cuda()
# 否则,定义为一般的Tensor类型
else:
  inputs = tmp

2)Variable(变量)

Pytorch里面的Variable类型数据功能更加强大,相当于是在Tensor外层套了一个壳子,这个壳子赋予了前向传播,反向传播,自动求导等功能,在计算图的构建中起的很重要的作用。Variable的结构图如下:

PyTorch基本数据类型(一)

其中最重要的两个属性是:data和grad。Data表示该变量保存的实际数据,通过该属性可以访问到它所保存的原始张量类型,而关于该 variable(变量)的梯度会被累计到.grad 上去。

在使用Variable的时候需要从torch.autograd中导入。下面通过一个例子来看一下它自动求导的过程:

import torch
from torch.autograd import Variable
 
# 定义三个Variable变量
x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
w = Variable(torch.Tensor([2, 3, 4]), requires_grad=True)
b = Variable(torch.Tensor([3, 4, 5]), requires_grad=True)
 
# 构建计算图,公式为:y = w * x^2 + b
y = w * x * x + b
 
# 自动求导,计算梯度
y.backward(torch.Tensor([1, 1, 1]))
 
print(x.grad)
print(w.grad)
print(b.grad)

上述代码的计算图为y = w * x^2 + b。对x, w, b分别求偏导为:x.grad = 2wx,w.grad=x^2,b.grad=1。代值检验可得计算结果是正确的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python自动化测试之setUp与tearDown实例
Sep 28 Python
Python3读取文件常用方法实例分析
May 22 Python
Django重装mysql后启动报错:No module named ‘MySQLdb’的解决方法
Apr 22 Python
python批量赋值操作实例
Oct 22 Python
使用python爬取抖音视频列表信息
Jul 15 Python
Python实现平行坐标图的绘制(plotly)方式
Nov 22 Python
Python实现微信好友的数据分析
Dec 16 Python
python logging 日志的级别调整方式
Feb 21 Python
TensorFlow2.1.0最新版本安装详细教程
Apr 08 Python
pandas分组聚合详解
Apr 10 Python
Python Django ORM连表正反操作技巧
Jun 13 Python
Python中tqdm的使用和例子
Sep 23 Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
You might like
php 按指定元素值去除数组元素的实现方法
2011/11/04 PHP
PHP实现把数字ID转字母ID
2013/08/12 PHP
JavaScript与Div对层定位和移动获得坐标的实现代码
2010/09/08 Javascript
屏蔽网页右键复制和ctrl+c复制的js代码
2013/01/04 Javascript
jquery操作select详解(取值,设置选中)
2014/02/07 Javascript
JavaScript DOM进阶方法
2015/04/13 Javascript
js+html5实现canvas绘制镂空字体文本的方法
2015/06/05 Javascript
javascript判断firebug是否开启的方法
2016/11/23 Javascript
canvas红包照片实例分享
2017/02/28 Javascript
微信小程序request出现400的问题解决办法
2017/05/23 Javascript
微信小程序开发之实现自定义Toast弹框
2017/06/08 Javascript
浅析JS抽象工厂模式
2017/12/14 Javascript
select2 ajax 设置默认值,初始值的方法
2018/08/09 Javascript
vue实现图片预览组件封装与使用
2019/07/13 Javascript
解决Layui数据表格的宽高问题
2019/09/28 Javascript
原生JavaScript写出Tabs标签页的实例代码
2020/07/20 Javascript
Python的gevent框架的入门教程
2015/04/29 Python
Python基础入门之seed()方法的使用
2015/05/15 Python
python实现excel读写数据
2021/03/02 Python
浅谈PYTHON 关于文件的操作
2019/03/19 Python
python自动循环定时开关机(非重启)测试
2019/08/26 Python
python实现计算器功能
2019/10/31 Python
wxPython实现列表增删改查功能
2019/11/19 Python
python pygame实现球球大作战
2019/11/25 Python
利用Tensorflow构建和训练自己的CNN来做简单的验证码识别方式
2020/01/20 Python
Python django框架开发发布会签到系统(web开发)
2020/02/12 Python
如何基于python实现年会抽奖工具
2020/10/20 Python
HTML5网页录音和上传到服务器支持PC、Android,支持IOS微信功能
2019/04/26 HTML / CSS
通过HTML5 Canvas API绘制弧线和圆形的教程
2016/03/14 HTML / CSS
速比涛英国官网:Speedo英国
2019/07/15 全球购物
什么是规则表达式
2012/05/03 面试题
毕业生个人投资创业计划书
2014/01/04 职场文书
小学少先队活动方案
2014/02/18 职场文书
关于九一八事变的演讲稿2014
2014/09/17 职场文书
担保书范文
2019/07/09 职场文书
Redis安装启动及常见数据类型
2021/04/14 Redis