详解pytorch 0.4.0迁移指南


Posted in Python onJune 16, 2019

总说

由于pytorch 0.4版本更新实在太大了, 以前版本的代码必须有一定程度的更新. 主要的更新在于 Variable和Tensor的合并., 当然还有Windows的支持, 其他一些就是支持scalar tensor以及修复bug和提升性能吧. Variable和Tensor的合并导致以前的代码会出错, 所以需要迁移, 其实迁移代价并不大.

Tensor和Variable的合并

说是合并, 其实是按照以前(0.1-0.3版本)的观点是: Tensor现在默认requires_grad=False的Variable了.torch.Tensortorch.autograd.Variable现在其实是同一个类! 没有本质的区别! 所以也就是说,现在已经没有纯粹的Tensor了, 是个Tensor, 它就支持自动求导!你现在要不要给Tensor包一下Variable, 都没有任何意义了.

查看Tensor的类型

使用.isinstance()或是x.type(), 用type()不能看tensor的具体类型.

>>> x = torch.DoubleTensor([1, 1, 1])
>>> print(type(x)) # was torch.DoubleTensor
"<class 'torch.Tensor'>"
>>> print(x.type()) # OK: 'torch.DoubleTensor'
'torch.DoubleTensor'
>>> print(isinstance(x, torch.DoubleTensor)) # OK: True
True

requires_grad 已经是Tensor的一个属性了

>>> x = torch.ones(1)
>>> x.requires_grad #默认是False
False
>>> y = torch.ones(1)
>>> z = x + y
>>> # 显然z的该属性也是False
>>> z.requires_grad
False
>>> # 所有变量都不需要grad, 所以会出错
>>> z.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>> # 可以将`requires_grad`作为一个参数, 构造tensor
>>> w = torch.ones(1, requires_grad=True)
>>> w.requires_grad
True
>>> total = w + z
>>> total.requires_grad
True
>>> # 现在可以backward了
>>> total.backward()
>>> w.grad
tensor([ 1.])
>>> # x,y,z都是不需要梯度的,他们的grad也没有计算
>>> z.grad == x.grad == y.grad == None
True

通过.requires_grad()来进行使得Tensor需要梯度.

不要随便用.data

以前.data是为了拿到Variable中的Tensor,但是后来, 两个都合并了. 所以.data返回一个新的requires_grad=False的Tensor!然而新的这个Tensor与以前那个Tensor是共享内存的. 所以不安全, 因为

y = x.data # x需要进行autograd
# y和x是共享内存的,但是这里y已经不需要grad了, 
# 所以会导致本来需要计算梯度的x也没有梯度可以计算.从而x不会得到更新!

所以, 推荐用x.detach(), 这个仍旧是共享内存的, 也是使得y的requires_grad为False,但是,如果x需要求导, 仍旧是可以自动求导的!

scalar的支持

这个非常重要啊!以前indexing一个一维Tensor,返回的是一个number类型,但是indexing一个Variable确实返回一个size为(1,)的vector.再比如一些reduction操作, 比如tensor.sum()返回一个number, 但是variable.sum()返回的是一个size为(1,)的vector.

scalar是0-维度的Tensor, 所以我们不能简单的用以前的方法创建, 我们用一个torch.tensor注意,是小写的!

y = x.data # x需要进行autograd
# y和x是共享内存的,但是这里y已经不需要grad了, 
# 所以会导致本来需要计算梯度的x也没有梯度可以计算.从而x不会得到更新!

从上面例子可以看出, 通过引入scalar, 可以将返回值的类型进行统一.
重点:
1. 取得一个tensor的值(返回number), 用.item()
2. 创建scalar的话,需要用torch.tensor(number)
3.torch.tensor(list)也可以进行创建tensor

累加loss

以前了累加loss(为了看loss的大小)一般是用total_loss+=loss.data[0], 比较诡异的是, 为啥是.data[0]? 这是因为, 这是因为loss是一个Variable, 所以以后累加loss, 用loss.item().
这个是必须的, 如果直接加, 那么随着训练的进行, 会导致后来的loss具有非常大的graph, 可能会超内存. 然而total_loss只是用来看的, 所以没必要进行维持这个graph!

弃用volatile

现在这个flag已经没用了. 被替换成torch.no_grad(),torch.set_grad_enable(grad_mode)等函数

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

dypes,devices以及numpy-style的构造函数

dtype是data types, 对应关系如下:

详解pytorch 0.4.0迁移指南

通过.dtype可以得到

其他就是以前写device type都是用.cup()或是.cuda(), 现在独立成一个函数, 我们可以

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344, 0.8562, -1.2758],
    [ 0.8414, 1.7962, 1.0589],
    [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

新的创建Tensor方法

主要是可以指定dtype以及device.

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344, 0.8562, -1.2758],
    [ 0.8414, 1.7962, 1.0589],
    [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

用 torch.tensor来创建Tensor

这个等价于numpy.array,用途:
1.将python list的数据用来创建Tensor
2. 创建scalar

# 从列表中, 创建tensor
>>> cuda = torch.device("cuda")
>>> torch.tensor([[1], [2], [3]], dtype=torch.half, device=cuda)
tensor([[ 1],
    [ 2],
    [ 3]], device='cuda:0')

>>> torch.tensor(1)        # 创建scalar
tensor(1)

torch.*like以及torch.new_*

第一个是可以创建, shape相同, 数据类型相同.

>>> x = torch.randn(3, dtype=torch.float64)
 >>> torch.zeros_like(x)
 tensor([ 0., 0., 0.], dtype=torch.float64)
 >>> torch.zeros_like(x, dtype=torch.int)
 tensor([ 0, 0, 0], dtype=torch.int32)

当然如果是单纯想要得到属性与前者相同的Tensor, 但是shape不想要一致:

>>> x = torch.randn(3, dtype=torch.float64)
 >>> x.new_ones(2) # 属性一致
 tensor([ 1., 1.], dtype=torch.float64)
 >>> x.new_ones(4, dtype=torch.int)
 tensor([ 1, 1, 1, 1], dtype=torch.int32)

书写 device-agnostic 的代码

这个含义是, 不要显示的指定是gpu, cpu之类的. 利用.to()来执行.

# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)

迁移代码对比

以前的写法

model = MyRNN()
 if use_cuda:
   model = model.cuda()

 # train
 total_loss = 0
 for input, target in train_loader:
   input, target = Variable(input), Variable(target)
   hidden = Variable(torch.zeros(*h_shape)) # init hidden
   if use_cuda:
     input, target, hidden = input.cuda(), target.cuda(), hidden.cuda()
   ... # get loss and optimize
   total_loss += loss.data[0]

 # evaluate
 for input, target in test_loader:
   input = Variable(input, volatile=True)
   if use_cuda:
     ...
   ...

现在的写法

# torch.device object used throughout this script
 device = torch.device("cuda" if use_cuda else "cpu")

 model = MyRNN().to(device)

 # train
 total_loss = 0
 for input, target in train_loader:
   input, target = input.to(device), target.to(device)
   hidden = input.new_zeros(*h_shape) # has the same device & dtype as `input`
   ... # get loss and optimize
   total_loss += loss.item()      # get Python number from 1-element Tensor

 # evaluate
 with torch.no_grad():          # operations inside don't track history
   for input, target in test_loader:
     ...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的zipfile模块使用详解
Jun 25 Python
python网络编程调用recv函数完整接收数据的三种方法
Mar 31 Python
Python unittest模块用法实例分析
May 25 Python
Python使用pickle模块储存对象操作示例
Aug 15 Python
Pycharm保存不能自动同步到远程服务器的解决方法
Jun 27 Python
python挖矿算力测试程序详解
Jul 03 Python
springboot配置文件抽离 git管理统 配置中心详解
Sep 02 Python
python生成器用法实例详解
Nov 22 Python
python脚本后台执行方式
Dec 21 Python
对django 2.x版本中models.ForeignKey()外键说明介绍
Mar 30 Python
Pytest中skip和skipif的具体使用方法
Jun 30 Python
分享提高 Python 代码的可读性的技巧
Mar 03 Python
对pyqt5多线程正确的开启姿势详解
Jun 14 #Python
Python+PyQT5的子线程更新UI界面的实例
Jun 14 #Python
在PYQT5中QscrollArea(滚动条)的使用方法
Jun 14 #Python
PYQT5设置textEdit自动滚屏的方法
Jun 14 #Python
使用PyQt4 设置TextEdit背景的方法
Jun 14 #Python
Ubuntu18.04中Python2.7与Python3.6环境切换
Jun 14 #Python
ubuntu 16.04下python版本切换的方法
Jun 14 #Python
You might like
如何解决CI框架的Disallowed Key Characters错误提示
2013/07/05 PHP
最新制作ThinkPHP3.2.3完全开发手册
2015/11/23 PHP
根据分辩率调用不同的CSS.
2007/01/08 Javascript
JavaScript 模仿vbs中的 DateAdd() 函数的代码
2007/08/13 Javascript
滚动条变色 隐藏滚动条与双击网页自动滚屏显示代码
2009/12/28 Javascript
基于jquery的无限级联下拉框js插件
2011/10/29 Javascript
js图片向右一张张滚动效果实例代码
2013/11/23 Javascript
一个JavaScript获取元素当前高度的实例
2014/10/29 Javascript
node.js中的console.dir方法使用说明
2014/12/10 Javascript
jQuery版本升级踩坑大全
2016/01/12 Javascript
angular分页指令操作
2017/01/09 Javascript
JS中静态页面实现微信分享功能
2017/02/06 Javascript
总结JavaScript在IE9之前版本中内存泄露问题
2018/04/28 Javascript
Angularjs实现多图片上传预览功能
2018/07/18 Javascript
vue+iview 兼容IE11浏览器的实现方法
2019/01/07 Javascript
JavaScript Date对象功能与用法学习记录
2020/04/28 Javascript
js实现滑动滑块验证登录
2020/07/24 Javascript
[01:09]DOTA2次级职业联赛 - 99战队宣传片
2014/12/01 DOTA
[51:17]VGJ.T vs Mineski 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
python网络编程学习笔记(四):域名系统
2014/06/09 Python
自己编程中遇到的Python错误和解决方法汇总整理
2015/06/03 Python
在Python中分别打印列表中的每一个元素方法
2018/11/07 Python
使用matplotlib动态刷新指定曲线实例
2020/04/23 Python
接口自动化多层嵌套json数据处理代码实例
2020/11/20 Python
老生常谈CSS中的长度单位
2016/06/27 HTML / CSS
Holiday Inn中国官网:IHG旗下假日酒店预订
2018/04/08 全球购物
美国最大的在线寄售和旧货店:Swap.com
2018/08/27 全球购物
Sunglass Hut巴西网上商店:男女太阳镜
2020/10/04 全球购物
十八大感想感言
2014/02/10 职场文书
纠风工作实施方案
2014/03/15 职场文书
2014年五四青年节演讲稿范文
2014/04/22 职场文书
个人股份转让协议书范本
2015/01/28 职场文书
幼儿园教师个人总结
2015/02/05 职场文书
毕业答辩开场白范文
2015/05/27 职场文书
比赛主持人开场白
2015/05/29 职场文书
详解Mysql数据库平滑扩容解决高并发和大数据量问题
2022/05/25 MySQL