Pytorch 实现变量类型转换


Posted in Python onMay 17, 2021

Pytorch的数据类型为各式各样的Tensor,Tensor可以理解为高维矩阵。

与Numpy中的Array类似。Pytorch中的tensor又包括CPU上的数据类型和GPU上的数据类型,一般GPU上的Tensor是CPU上的Tensor加cuda()函数得到。通过使用Type函数可以查看变量类型。

一般系统默认是torch.FloatTensor类型。

例如data = torch.Tensor(2,3)是一个2*3的张量,类型为FloatTensor; data.cuda()就转换为GPU的张量类型,torch.cuda.FloatTensor类型。

下面简单介绍一下Pytorch中变量之间的相互转换

(1)CPU或GPU张量之间的转换

一般只要在Tensor后加long(), int(), double(),float(),byte()等函数就能将Tensor进行类型转换;

例如:Torch.LongTensor--->Torch.FloatTensor, 直接使用data.float()即可

还可以使用type()函数,data为Tensor数据类型,data.type()为给出data的类型,如果使用data.type(torch.FloatTensor)则强制转换为torch.FloatTensor类型张量。

当你不知道要转换为什么类型时,但需要求a1,a2两个张量的乘积,可以使用a1.type_as(a2)将a1转换为a2同类型。

(2)CPU张量 ----> GPU张量, 使用data.cuda()

(3)GPU张量 ----> CPU张量 使用data.cpu()

(4)Variable变量转换成普通的Tensor,其实可以理解Variable为一个Wrapper,里头的data就是Tensor. 如果Var是Variable变量,使用Var.data获得Tensor变量

(5)Tensor与Numpy Array之间的转换

Tensor---->Numpy 可以使用 data.numpy(),data为Tensor变量

Numpy ----> Tensor 可以使用torch.from_numpy(data),data为numpy变量

补充:Numpy/Pytorch之数据类型与强制类型转换

1.数据类型简介

Numpy

NumPy 支持比 Python 更多种类的数值类型。 下表显示了 NumPy 中定义的不同标量数据类型。

序号 数据类型及描述
1. bool_存储为一个字节的布尔值(真或假)
2. int_默认整数,相当于 C 的long,通常为int32或int64
3. intc相当于 C 的int,通常为int32或int64
4. intp用于索引的整数,相当于 C 的size_t,通常为int32或int64
5. int8字节(-128 ~ 127)
6. int1616 位整数(-32768 ~ 32767)
7. int3232 位整数(-2147483648 ~ 2147483647)
8. int6464 位整数(-9223372036854775808 ~ 9223372036854775807)
9. uint88 位无符号整数(0 ~ 255)
10. uint1616 位无符号整数(0 ~ 65535)
11. uint3232 位无符号整数(0 ~ 4294967295)
12. uint6464 位无符号整数(0 ~ 18446744073709551615)
13. float_float64的简写
14. float16半精度浮点:符号位,5 位指数,10 位尾数
15. float32单精度浮点:符号位,8 位指数,23 位尾数
16. float64双精度浮点:符号位,11 位指数,52 位尾数
17. complex_complex128的简写
18. complex64复数,由两个 32 位浮点表示(实部和虚部)
19.

complex128复数,由两个 64 位浮点表示(实部和虚部)

直接使用类型名很可能会报错,正确的使用方式是np.调用,eg, np.uint8

Pytorch

Torch定义了七种CPU张量类型和八种GPU张量类型,这里我们就只讲解一下CPU中的,其实GPU中只是中间加一个cuda即可,如torch.cuda.FloatTensor:

torch.FloatTensor(2,3) 构建一个2*3 Float类型的张量

torch.DoubleTensor(2,3) 构建一个2*3 Double类型的张量

torch.ByteTensor(2,3) 构建一个2*3 Byte类型的张量

torch.CharTensor(2,3) 构建一个2*3 Char类型的张量

torch.ShortTensor(2,3) 构建一个2*3 Short类型的张量

torch.IntTensor(2,3) 构建一个2*3 Int类型的张量

torch.LongTensor(2,3) 构建一个2*3 Long类型的张量

同样,直接使用类型名很可能会报错,正确的使用方式是torch.调用,eg,torch.FloatTensor()

2.Python的type()函数

type函数可以由变量调用,或者把变量作为参数传入。

返回的是该变量的类型,而非数据类型。

data = np.random.randint(0, 255, 300)
print(type(data))

输出

<class 'numpy.ndarray'>

3.Numpy/Pytorch的dtype属性

返回值为变量的数据类型

t_out = torch.Tensor(1,2,3)
print(t_out.dtype)

输出

torch.float32

t_out = torch.Tensor(1,2,3)

print(t_out.numpy().dtype)

输出

float32

4.Numpy中的类型转换

先聊聊我为什么会用到这个函数(不看跳过)

为了实施trochvision.transforms.ToPILImage()函数

于是我想从numpy的ndarray类型转成PILImage类型

我做了以下尝试

data = np.random.randint(0, 255, 300)
n_out = data.reshape(10,10,3)
print(n_out.dtype)
img = transforms.ToPILImage()(n_out)
img.show()

但是很遗憾,报错了

raise TypeError('Input type {} is not supported'.format(npimg.dtype))

TypeError: Input type int32 is not supported

因为要将ndarray转成PILImage要求ndarray是uint8类型的。

于是我认输了。。。

使用了

n_out = np.linspace(0,255,300,dtype=np.uint8)
n_out = n_out.reshape(10,10,3)
print(n_out.dtype)
img = torchvision.transforms.ToPILImage()(n_out)
img.show()

得到了输出

uint8

Pytorch 实现变量类型转换

嗯,显示了一张图片

但是呢,就很憋屈,和想要的随机数效果不一样。

于是我用了astype函数

astype()函数

由变量调用,但是直接调用不会改变原变量的数据类型,是返回值是改变类型后的新变量,所以要赋值回去。

n_out = n_out.astype(np.uint8)
#初始化随机数种子
np.random.seed(0)
 
data = np.random.randint(0, 255, 300)
print(data.dtype)
n_out = data.reshape(10,10,3)
 
#强制类型转换
n_out = n_out.astype(np.uint8)
print(n_out.dtype)
 
img = transforms.ToPILImage()(n_out)
img.show()

输出

int32

uint8

Pytorch 实现变量类型转换

5.Pytorch中的类型转换

pytorch中没有astype函数,正确的转换方法是

Way1 : 变量直接调用类型

tensor = torch.Tensor(3, 5)

torch.long() 将tensor投射为long类型

newtensor = tensor.long()

torch.half()将tensor投射为半精度浮点类型

newtensor = tensor.half()

torch.int()将该tensor投射为int类型

newtensor = tensor.int()

torch.double()将该tensor投射为double类型

newtensor = tensor.double()

torch.float()将该tensor投射为float类型

newtensor = tensor.float()

torch.char()将该tensor投射为char类型

newtensor = tensor.char()

torch.byte()将该tensor投射为byte类型

newtensor = tensor.byte()

torch.short()将该tensor投射为short类型

newtensor = tensor.short()

同样,和numpy中的astype函数一样,是返回值才是改变类型后的结果,调用的变量类型不变

Way2 : 变量调用pytorch中的type函数

type(new_type=None, async=False)如果未提供new_type,则返回类型,否则将此对象转换为指定的类型。 如果已经是正确的类型,则不会执行且返回原对象。

用法如下:

self = torch.LongTensor(3, 5)
# 转换为其他类型
print self.type(torch.FloatTensor)

Way3 : 变量调用pytorch中的type_as函数

如果张量已经是正确的类型,则不会执行操作。具体操作方法如下:

self = torch.Tensor(3, 5)
tesnor = torch.IntTensor(2,3)
print self.type_as(tesnor)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python采用requests库模拟登录和抓取数据的简单示例
Jul 05 Python
用Python代码来绘制彭罗斯点阵的教程
Apr 03 Python
python实现自动重启本程序的方法
Jul 09 Python
5种Python单例模式的实现方式
Jan 14 Python
Python第三方Window模块文件的几种安装方法
Nov 22 Python
Apache,wsgi,django 程序部署配置方法详解
Jul 01 Python
Form表单及django的form表单的补充
Jul 25 Python
Django处理Ajax发送的Get请求代码详解
Jul 29 Python
python 异步async库的使用说明
May 04 Python
PyCharm中如何直接使用Anaconda已安装的库
May 28 Python
python 爬虫之selenium可视化爬虫的实现
Dec 04 Python
Python List remove()实例用法详解
Aug 02 Python
Python进度条的使用
May 17 #Python
Python包管理工具pip的15 个使用小技巧
Python中json.dumps()函数的使用解析
May 17 #Python
Python中threading库实现线程锁与释放锁
Python中Cookies导出某站用户数据的方法
May 17 #Python
Python 高级库15 个让新手爱不释手(推荐)
Python带你从浅入深探究Tuple(基础篇)
May 15 #Python
You might like
解决中英文字符串长度问题函数
2007/01/16 PHP
删除及到期域名的查看(抢域名必备哦)
2008/05/14 PHP
PHP file_exists问题杂谈
2012/05/07 PHP
克隆一个新项目的快捷方式
2013/04/10 PHP
PHP轻量级数据库操作类Medoo增加、删除、修改、查询例子
2014/07/04 PHP
深入理解PHP中的Streams工具
2015/07/03 PHP
twig模板获取全局变量的方法
2016/02/05 PHP
php 防止表单重复提交两种实现方法
2016/11/03 PHP
php遍历目录下文件并按修改时间排序操作示例
2019/07/12 PHP
使用Rancher在K8S上部署高性能PHP应用程序的教程
2020/07/10 PHP
JavaScript 事件查询综合
2009/07/13 Javascript
jquery last-child 列表最后一项的样式
2010/01/22 Javascript
JavaScript之Getters和Setters 平台支持等详细介绍
2012/12/07 Javascript
JS自动适应的图片弹窗实例
2013/06/29 Javascript
JavaScript显示当然日期和时间即年月日星期和时间
2013/10/29 Javascript
实例详解JavaScript获取链接参数的方法
2016/01/01 Javascript
jQuery的Each比JS原生for循环性能慢很多的原因
2016/07/05 Javascript
AngularJS $injector 依赖注入详解
2016/09/14 Javascript
Angular 4 依赖注入学习教程之FactoryProvider的使用(四)
2017/06/04 Javascript
微信页面弹出键盘后iframe内容变空白的解决方案
2017/09/20 Javascript
js实现mp3录音通过websocket实时传送+简易波形图效果
2020/06/12 Javascript
vue单元格多列合并的实现
2020/11/26 Vue.js
[03:17]2014DOTA2 国际邀请赛中国区预选赛 四强专访
2014/05/23 DOTA
python操作ie登陆土豆网的方法
2015/05/09 Python
python读取与写入csv格式文件的示例代码
2017/12/16 Python
基于python实现KNN分类算法
2020/04/23 Python
python实现鸢尾花三种聚类算法(K-means,AGNES,DBScan)
2019/06/27 Python
pandas取出重复数据的方法
2019/07/04 Python
Python:二维列表下标互换方式(矩阵转置)
2019/12/02 Python
用python读取xlsx文件
2020/12/17 Python
AmazeUI图片轮播效果的示例代码
2020/08/20 HTML / CSS
应届毕业生求职信范文分享
2013/12/26 职场文书
国庆促销活动总结
2014/08/29 职场文书
天坛导游词
2015/02/02 职场文书
Java对文件的读写操作方法
2022/04/29 Java/Android
码云(gitee)通过git自动同步到阿里云服务器
2022/12/24 Servers