详解pytorch 0.4.0迁移指南


Posted in Python onJune 16, 2019

总说

由于pytorch 0.4版本更新实在太大了, 以前版本的代码必须有一定程度的更新. 主要的更新在于 Variable和Tensor的合并., 当然还有Windows的支持, 其他一些就是支持scalar tensor以及修复bug和提升性能吧. Variable和Tensor的合并导致以前的代码会出错, 所以需要迁移, 其实迁移代价并不大.

Tensor和Variable的合并

说是合并, 其实是按照以前(0.1-0.3版本)的观点是: Tensor现在默认requires_grad=False的Variable了.torch.Tensortorch.autograd.Variable现在其实是同一个类! 没有本质的区别! 所以也就是说,现在已经没有纯粹的Tensor了, 是个Tensor, 它就支持自动求导!你现在要不要给Tensor包一下Variable, 都没有任何意义了.

查看Tensor的类型

使用.isinstance()或是x.type(), 用type()不能看tensor的具体类型.

>>> x = torch.DoubleTensor([1, 1, 1])
>>> print(type(x)) # was torch.DoubleTensor
"<class 'torch.Tensor'>"
>>> print(x.type()) # OK: 'torch.DoubleTensor'
'torch.DoubleTensor'
>>> print(isinstance(x, torch.DoubleTensor)) # OK: True
True

requires_grad 已经是Tensor的一个属性了

>>> x = torch.ones(1)
>>> x.requires_grad #默认是False
False
>>> y = torch.ones(1)
>>> z = x + y
>>> # 显然z的该属性也是False
>>> z.requires_grad
False
>>> # 所有变量都不需要grad, 所以会出错
>>> z.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>> # 可以将`requires_grad`作为一个参数, 构造tensor
>>> w = torch.ones(1, requires_grad=True)
>>> w.requires_grad
True
>>> total = w + z
>>> total.requires_grad
True
>>> # 现在可以backward了
>>> total.backward()
>>> w.grad
tensor([ 1.])
>>> # x,y,z都是不需要梯度的,他们的grad也没有计算
>>> z.grad == x.grad == y.grad == None
True

通过.requires_grad()来进行使得Tensor需要梯度.

不要随便用.data

以前.data是为了拿到Variable中的Tensor,但是后来, 两个都合并了. 所以.data返回一个新的requires_grad=False的Tensor!然而新的这个Tensor与以前那个Tensor是共享内存的. 所以不安全, 因为

y = x.data # x需要进行autograd
# y和x是共享内存的,但是这里y已经不需要grad了, 
# 所以会导致本来需要计算梯度的x也没有梯度可以计算.从而x不会得到更新!

所以, 推荐用x.detach(), 这个仍旧是共享内存的, 也是使得y的requires_grad为False,但是,如果x需要求导, 仍旧是可以自动求导的!

scalar的支持

这个非常重要啊!以前indexing一个一维Tensor,返回的是一个number类型,但是indexing一个Variable确实返回一个size为(1,)的vector.再比如一些reduction操作, 比如tensor.sum()返回一个number, 但是variable.sum()返回的是一个size为(1,)的vector.

scalar是0-维度的Tensor, 所以我们不能简单的用以前的方法创建, 我们用一个torch.tensor注意,是小写的!

y = x.data # x需要进行autograd
# y和x是共享内存的,但是这里y已经不需要grad了, 
# 所以会导致本来需要计算梯度的x也没有梯度可以计算.从而x不会得到更新!

从上面例子可以看出, 通过引入scalar, 可以将返回值的类型进行统一.
重点:
1. 取得一个tensor的值(返回number), 用.item()
2. 创建scalar的话,需要用torch.tensor(number)
3.torch.tensor(list)也可以进行创建tensor

累加loss

以前了累加loss(为了看loss的大小)一般是用total_loss+=loss.data[0], 比较诡异的是, 为啥是.data[0]? 这是因为, 这是因为loss是一个Variable, 所以以后累加loss, 用loss.item().
这个是必须的, 如果直接加, 那么随着训练的进行, 会导致后来的loss具有非常大的graph, 可能会超内存. 然而total_loss只是用来看的, 所以没必要进行维持这个graph!

弃用volatile

现在这个flag已经没用了. 被替换成torch.no_grad(),torch.set_grad_enable(grad_mode)等函数

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

dypes,devices以及numpy-style的构造函数

dtype是data types, 对应关系如下:

详解pytorch 0.4.0迁移指南

通过.dtype可以得到

其他就是以前写device type都是用.cup()或是.cuda(), 现在独立成一个函数, 我们可以

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344, 0.8562, -1.2758],
    [ 0.8414, 1.7962, 1.0589],
    [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

新的创建Tensor方法

主要是可以指定dtype以及device.

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344, 0.8562, -1.2758],
    [ 0.8414, 1.7962, 1.0589],
    [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

用 torch.tensor来创建Tensor

这个等价于numpy.array,用途:
1.将python list的数据用来创建Tensor
2. 创建scalar

# 从列表中, 创建tensor
>>> cuda = torch.device("cuda")
>>> torch.tensor([[1], [2], [3]], dtype=torch.half, device=cuda)
tensor([[ 1],
    [ 2],
    [ 3]], device='cuda:0')

>>> torch.tensor(1)        # 创建scalar
tensor(1)

torch.*like以及torch.new_*

第一个是可以创建, shape相同, 数据类型相同.

>>> x = torch.randn(3, dtype=torch.float64)
 >>> torch.zeros_like(x)
 tensor([ 0., 0., 0.], dtype=torch.float64)
 >>> torch.zeros_like(x, dtype=torch.int)
 tensor([ 0, 0, 0], dtype=torch.int32)

当然如果是单纯想要得到属性与前者相同的Tensor, 但是shape不想要一致:

>>> x = torch.randn(3, dtype=torch.float64)
 >>> x.new_ones(2) # 属性一致
 tensor([ 1., 1.], dtype=torch.float64)
 >>> x.new_ones(4, dtype=torch.int)
 tensor([ 1, 1, 1, 1], dtype=torch.int32)

书写 device-agnostic 的代码

这个含义是, 不要显示的指定是gpu, cpu之类的. 利用.to()来执行.

# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)

迁移代码对比

以前的写法

model = MyRNN()
 if use_cuda:
   model = model.cuda()

 # train
 total_loss = 0
 for input, target in train_loader:
   input, target = Variable(input), Variable(target)
   hidden = Variable(torch.zeros(*h_shape)) # init hidden
   if use_cuda:
     input, target, hidden = input.cuda(), target.cuda(), hidden.cuda()
   ... # get loss and optimize
   total_loss += loss.data[0]

 # evaluate
 for input, target in test_loader:
   input = Variable(input, volatile=True)
   if use_cuda:
     ...
   ...

现在的写法

# torch.device object used throughout this script
 device = torch.device("cuda" if use_cuda else "cpu")

 model = MyRNN().to(device)

 # train
 total_loss = 0
 for input, target in train_loader:
   input, target = input.to(device), target.to(device)
   hidden = input.new_zeros(*h_shape) # has the same device & dtype as `input`
   ... # get loss and optimize
   total_loss += loss.item()      # get Python number from 1-element Tensor

 # evaluate
 with torch.no_grad():          # operations inside don't track history
   for input, target in test_loader:
     ...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用 Python 获取 Linux 系统信息的代码
Jul 13 Python
Python中的ctime()方法使用教程
May 22 Python
python 获取文件下所有文件或目录os.walk()的实例
Apr 23 Python
解决python中使用plot画图,图不显示的问题
Jul 04 Python
对pycharm代码整体左移和右移缩进快捷键的介绍
Jul 16 Python
python如何实现数据的线性拟合
Jul 19 Python
python读写csv文件的方法
Aug 13 Python
python画蝴蝶曲线图的实例
Nov 21 Python
Python之指数与E记法的区别详解
Nov 21 Python
python代码实现图书管理系统
Nov 30 Python
Python实现简单猜数字游戏
Feb 03 Python
理解python中装饰器的作用
Jul 21 Python
对pyqt5多线程正确的开启姿势详解
Jun 14 #Python
Python+PyQT5的子线程更新UI界面的实例
Jun 14 #Python
在PYQT5中QscrollArea(滚动条)的使用方法
Jun 14 #Python
PYQT5设置textEdit自动滚屏的方法
Jun 14 #Python
使用PyQt4 设置TextEdit背景的方法
Jun 14 #Python
Ubuntu18.04中Python2.7与Python3.6环境切换
Jun 14 #Python
ubuntu 16.04下python版本切换的方法
Jun 14 #Python
You might like
php Ubb代码编辑器函数代码
2012/07/05 PHP
PHP实现从远程下载文件的方法
2015/03/12 PHP
简单解决新浪SAE无法上传文件的问题
2015/05/13 PHP
比较新旧两个数组值得增加和删除的JS代码
2013/10/30 Javascript
使用jquery清空、复位整个输入域
2015/04/02 Javascript
IE10中flexigrid无法显示数据的解决方法
2015/07/26 Javascript
JavaScript深度复制(deep clone)的实现方法
2016/02/19 Javascript
js仿支付宝填写支付密码效果实现多方框输入密码
2016/03/09 Javascript
移动端jQuery修正Web页面滑动时div问题的两则实例
2016/05/30 Javascript
微信小程序 Storage API实例详解
2016/10/02 Javascript
在vue中封装可复用的组件方法
2018/03/01 Javascript
vue实现商品加减计算总价的实例代码
2018/08/12 Javascript
vue.js的vue-cli脚手架中使用百度地图API的实例
2019/01/21 Javascript
node.js基于dgram数据报模块创建UDP服务器和客户端操作示例
2020/02/12 Javascript
微信小程序实现搜索功能
2020/03/10 Javascript
vue 实现把路由单独分离出来
2020/08/13 Javascript
vue中v-model对select的绑定操作
2020/08/31 Javascript
vue3.0中setup使用(两种用法)
2020/12/02 Vue.js
基于VUE实现简单的学生信息管理系统
2021/01/13 Vue.js
实例讲解Python中的私有属性
2014/08/21 Python
Python应用库大全总结
2018/05/30 Python
python tornado微信开发入门代码
2018/08/24 Python
django DRF图片路径问题的解决方法
2018/09/10 Python
解决项目pycharm能运行,在终端却无法运行的问题
2019/01/19 Python
django有外键关系的两张表如何相互查找
2020/02/10 Python
使用pyecharts1.7进行简单的可视化大全
2020/05/17 Python
比利时网上药店: Drogisterij.net
2017/03/17 全球购物
CHARLES & KEITH台湾官网:新加坡时尚品牌
2019/07/30 全球购物
JACK & JONES荷兰官网:男士服装和鞋子
2021/03/07 全球购物
阿联酋最好的手机、电子产品和家用电器网上商店:Eros Digital Home
2020/08/09 全球购物
建筑公司文秘岗位职责
2013/11/29 职场文书
《春笋》教学反思
2014/04/15 职场文书
公司慰问信范文
2015/03/23 职场文书
Python面向对象编程之类的概念
2021/11/01 Python
nginx刷新页面出现404解决方案(亲测有效)
2022/03/18 Servers
JS前端使用canvas实现扩展物体类和事件派发
2022/08/05 Javascript