PyTorch基本数据类型(一)


Posted in Python onMay 22, 2019

PyTorch基础入门一:PyTorch基本数据类型

1)Tensor(张量)

Pytorch里面处理的最基本的操作对象就是Tensor(张量),它表示的其实就是一个多维矩阵,并有矩阵相关的运算操作。在使用上和numpy是对应的,它和numpy唯一的不同就是,pytorch可以在GPU上运行,而numpy不可以。所以,我们也可以使用Tensor来代替numpy的使用。当然,二者也可以相互转换。

Tensor的基本数据类型有五种:

  • 32位浮点型:torch.FloatTensor。pyorch.Tensor()默认的就是这种类型。
  • 64位整型:torch.LongTensor。
  • 32位整型:torch.IntTensor。
  • 16位整型:torch.ShortTensor。
  • 64位浮点型:torch.DoubleTensor。

那么如何定义Tensor张量呢?其实定义的方式和numpy一样,直接传入相应的矩阵即可即可。下面就定义了一个三行两列的矩阵:

import torch
# 导包
 
a = torch.Tensor([[1, 2], [3, 4], [5, 6]])
print(a)

不过在项目之中,更多的做法是以特殊值或者随机值初始化一个矩阵,就像下面这样:

import torch
 
# 定义一个3行2列的全为0的矩阵
b = torch.zeros((3, 2))
 
# 定义一个3行2列的随机值矩阵
c = torch.randn((3, 2))
 
# 定义一个3行2列全为1的矩阵
d = torch.ones((3, 2))
 
print(b)
print(c)
print(d)

Tensor和numpy.ndarray之间还可以相互转换,其方式如下:

  • Numpy转化为Tensor:torch.from_numpy(numpy矩阵)
  • Tensor转化为numpy:Tensor矩阵.numpy()

范例如下:

import torch
import numpy as np
 
# 定义一个3行2列的全为0的矩阵
b = torch.randn((3, 2))
 
# tensor转化为numpy
numpy_b = b.numpy()
print(numpy_b)
 
# numpy转化为tensor
numpy_e = np.array([[1, 2], [3, 4], [5, 6]])
torch_e = torch.from_numpy(numpy_e)
 
print(numpy_e)
print(torch_e)

之前说过,numpy与Tensor最大的区别就是在对GPU的支持上。Tensor只需要调用cuda()函数就可以将其转化为能在GPU上运行的类型。

我们可以通过torch.cuda.is_available()函数来判断当前的环境是否支持GPU,如果支持,则返回True。所以,为保险起见,在项目代码中一般采取“先判断,后使用”的策略来保证代码的正常运行,其基本结构如下:

import torch
 
# 定义一个3行2列的全为0的矩阵
tmp = torch.randn((3, 2))
 
# 如果支持GPU,则定义为GPU类型
if torch.cuda.is_available():
  inputs = tmp.cuda()
# 否则,定义为一般的Tensor类型
else:
  inputs = tmp

2)Variable(变量)

Pytorch里面的Variable类型数据功能更加强大,相当于是在Tensor外层套了一个壳子,这个壳子赋予了前向传播,反向传播,自动求导等功能,在计算图的构建中起的很重要的作用。Variable的结构图如下:

PyTorch基本数据类型(一)

其中最重要的两个属性是:data和grad。Data表示该变量保存的实际数据,通过该属性可以访问到它所保存的原始张量类型,而关于该 variable(变量)的梯度会被累计到.grad 上去。

在使用Variable的时候需要从torch.autograd中导入。下面通过一个例子来看一下它自动求导的过程:

import torch
from torch.autograd import Variable
 
# 定义三个Variable变量
x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
w = Variable(torch.Tensor([2, 3, 4]), requires_grad=True)
b = Variable(torch.Tensor([3, 4, 5]), requires_grad=True)
 
# 构建计算图,公式为:y = w * x^2 + b
y = w * x * x + b
 
# 自动求导,计算梯度
y.backward(torch.Tensor([1, 1, 1]))
 
print(x.grad)
print(w.grad)
print(b.grad)

上述代码的计算图为y = w * x^2 + b。对x, w, b分别求偏导为:x.grad = 2wx,w.grad=x^2,b.grad=1。代值检验可得计算结果是正确的。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python pickle 和 shelve模块的用法
Sep 16 Python
python进程管理工具supervisor使用实例
Sep 17 Python
Python实现抓取网页并且解析的实例
Sep 20 Python
让python 3支持mysqldb的解决方法
Feb 14 Python
python pandas dataframe 按列或者按行合并的方法
Apr 12 Python
Python定时发送消息的脚本:每天跟你女朋友说晚安
Oct 21 Python
使用python批量读取word文档并整理关键信息到excel表格的实例
Nov 07 Python
Python 分享10个PyCharm技巧
Jul 13 Python
python3的数据类型及数据类型转换实例详解
Aug 20 Python
Python3并发写文件与Python对比
Nov 20 Python
Python如何读取、写入JSON数据
Jul 28 Python
PyTorch dropout设置训练和测试模式的实现
May 27 Python
PyTorch搭建多项式回归模型(三)
May 22 #Python
pytorch使用Variable实现线性回归
May 21 #Python
Python面向对象进阶学习
May 21 #Python
谈一谈基于python的面向对象编程基础
May 21 #Python
python字符串和常用数据结构知识总结
May 21 #Python
Opencv实现抠图背景图替换功能
May 21 #Python
python多进程读图提取特征存npy
May 21 #Python
You might like
Mysql和网页显示乱码解决方法集锦
2008/03/27 PHP
php Calender(日历)代码分享
2014/01/03 PHP
yii 框架实现按天,月,年,自定义时间段统计数据的方法分析
2020/04/04 PHP
纯js实现的论坛常用的运行代码的效果
2008/07/15 Javascript
Document 对象的常用方法
2009/07/31 Javascript
7个Javascript地图脚本整理
2009/10/20 Javascript
JS获取dom 对象 ajax操作 读写cookie函数
2009/11/18 Javascript
基于JQuery.timer插件实现一个计时器
2010/04/25 Javascript
JS 获取HTML标签内的子节点的方法
2016/09/21 Javascript
JavaScript实现的微信二维码图片生成器的示例
2016/10/26 Javascript
jQuery实现复制到粘贴板功能
2017/02/11 Javascript
3分钟掌握常用的JS操作JSON方法总结
2017/04/25 Javascript
BootStrap 页签切换失效的解决方法
2017/08/17 Javascript
bootstrap下拉框动态赋值方法
2018/08/10 Javascript
详解Vue demo实现商品列表的展示
2019/05/07 Javascript
javascript实现拖拽碰撞检测
2020/03/12 Javascript
javascript贪吃蛇游戏设计与实现
2020/09/17 Javascript
vue的webcamjs集成方式
2020/11/16 Javascript
Python入门篇之编程习惯与特点
2014/10/17 Python
Python __setattr__、 __getattr__、 __delattr__、__call__用法示例
2015/03/06 Python
Python实现模拟登录及表单提交的方法
2015/07/25 Python
Django自定义分页与bootstrap分页结合
2021/02/22 Python
Python三级菜单的实例
2017/09/13 Python
Python定时从Mysql提取数据存入Redis的实现
2020/05/03 Python
Django --Xadmin 判断登录者身份实例
2020/07/03 Python
推荐值得学习的12款python-web开发框架
2020/08/10 Python
CSS3中Animation属性的使用详解
2015/08/06 HTML / CSS
玩具反斗城葡萄牙官方商城:Toys"R"Us葡萄牙
2016/10/21 全球购物
美国定制钻石订婚戒指:Ritani
2017/12/08 全球购物
英国家具、照明、家居用品网上商店:Wayfair.co.uk
2020/02/13 全球购物
为什么group by 和order by会使查询变慢
2014/05/16 面试题
六一节目主持词
2014/04/01 职场文书
2014年政府采购工作总结
2014/12/09 职场文书
劳资员岗位职责
2015/02/13 职场文书
毕业晚宴祝酒词
2015/08/11 职场文书
css布局巧妙技巧之css三角示例的运用
2022/03/16 HTML / CSS