详解pytorch 0.4.0迁移指南


Posted in Python onJune 16, 2019

总说

由于pytorch 0.4版本更新实在太大了, 以前版本的代码必须有一定程度的更新. 主要的更新在于 Variable和Tensor的合并., 当然还有Windows的支持, 其他一些就是支持scalar tensor以及修复bug和提升性能吧. Variable和Tensor的合并导致以前的代码会出错, 所以需要迁移, 其实迁移代价并不大.

Tensor和Variable的合并

说是合并, 其实是按照以前(0.1-0.3版本)的观点是: Tensor现在默认requires_grad=False的Variable了.torch.Tensortorch.autograd.Variable现在其实是同一个类! 没有本质的区别! 所以也就是说,现在已经没有纯粹的Tensor了, 是个Tensor, 它就支持自动求导!你现在要不要给Tensor包一下Variable, 都没有任何意义了.

查看Tensor的类型

使用.isinstance()或是x.type(), 用type()不能看tensor的具体类型.

>>> x = torch.DoubleTensor([1, 1, 1])
>>> print(type(x)) # was torch.DoubleTensor
"<class 'torch.Tensor'>"
>>> print(x.type()) # OK: 'torch.DoubleTensor'
'torch.DoubleTensor'
>>> print(isinstance(x, torch.DoubleTensor)) # OK: True
True

requires_grad 已经是Tensor的一个属性了

>>> x = torch.ones(1)
>>> x.requires_grad #默认是False
False
>>> y = torch.ones(1)
>>> z = x + y
>>> # 显然z的该属性也是False
>>> z.requires_grad
False
>>> # 所有变量都不需要grad, 所以会出错
>>> z.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>> # 可以将`requires_grad`作为一个参数, 构造tensor
>>> w = torch.ones(1, requires_grad=True)
>>> w.requires_grad
True
>>> total = w + z
>>> total.requires_grad
True
>>> # 现在可以backward了
>>> total.backward()
>>> w.grad
tensor([ 1.])
>>> # x,y,z都是不需要梯度的,他们的grad也没有计算
>>> z.grad == x.grad == y.grad == None
True

通过.requires_grad()来进行使得Tensor需要梯度.

不要随便用.data

以前.data是为了拿到Variable中的Tensor,但是后来, 两个都合并了. 所以.data返回一个新的requires_grad=False的Tensor!然而新的这个Tensor与以前那个Tensor是共享内存的. 所以不安全, 因为

y = x.data # x需要进行autograd
# y和x是共享内存的,但是这里y已经不需要grad了, 
# 所以会导致本来需要计算梯度的x也没有梯度可以计算.从而x不会得到更新!

所以, 推荐用x.detach(), 这个仍旧是共享内存的, 也是使得y的requires_grad为False,但是,如果x需要求导, 仍旧是可以自动求导的!

scalar的支持

这个非常重要啊!以前indexing一个一维Tensor,返回的是一个number类型,但是indexing一个Variable确实返回一个size为(1,)的vector.再比如一些reduction操作, 比如tensor.sum()返回一个number, 但是variable.sum()返回的是一个size为(1,)的vector.

scalar是0-维度的Tensor, 所以我们不能简单的用以前的方法创建, 我们用一个torch.tensor注意,是小写的!

y = x.data # x需要进行autograd
# y和x是共享内存的,但是这里y已经不需要grad了, 
# 所以会导致本来需要计算梯度的x也没有梯度可以计算.从而x不会得到更新!

从上面例子可以看出, 通过引入scalar, 可以将返回值的类型进行统一.
重点:
1. 取得一个tensor的值(返回number), 用.item()
2. 创建scalar的话,需要用torch.tensor(number)
3.torch.tensor(list)也可以进行创建tensor

累加loss

以前了累加loss(为了看loss的大小)一般是用total_loss+=loss.data[0], 比较诡异的是, 为啥是.data[0]? 这是因为, 这是因为loss是一个Variable, 所以以后累加loss, 用loss.item().
这个是必须的, 如果直接加, 那么随着训练的进行, 会导致后来的loss具有非常大的graph, 可能会超内存. 然而total_loss只是用来看的, 所以没必要进行维持这个graph!

弃用volatile

现在这个flag已经没用了. 被替换成torch.no_grad(),torch.set_grad_enable(grad_mode)等函数

>>> x = torch.zeros(1, requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>>
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...   y = x * 2
>>> y.requires_grad
False
>>> torch.set_grad_enabled(True) # this can also be used as a function
>>> y = x * 2
>>> y.requires_grad
True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False

dypes,devices以及numpy-style的构造函数

dtype是data types, 对应关系如下:

详解pytorch 0.4.0迁移指南

通过.dtype可以得到

其他就是以前写device type都是用.cup()或是.cuda(), 现在独立成一个函数, 我们可以

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344, 0.8562, -1.2758],
    [ 0.8414, 1.7962, 1.0589],
    [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

新的创建Tensor方法

主要是可以指定dtype以及device.

>>> device = torch.device("cuda:1")
>>> x = torch.randn(3, 3, dtype=torch.float64, device=device)
tensor([[-0.6344, 0.8562, -1.2758],
    [ 0.8414, 1.7962, 1.0589],
    [-0.1369, -1.0462, -0.4373]], dtype=torch.float64, device='cuda:1')
>>> x.requires_grad # default is False
False
>>> x = torch.zeros(3, requires_grad=True)
>>> x.requires_grad
True

用 torch.tensor来创建Tensor

这个等价于numpy.array,用途:
1.将python list的数据用来创建Tensor
2. 创建scalar

# 从列表中, 创建tensor
>>> cuda = torch.device("cuda")
>>> torch.tensor([[1], [2], [3]], dtype=torch.half, device=cuda)
tensor([[ 1],
    [ 2],
    [ 3]], device='cuda:0')

>>> torch.tensor(1)        # 创建scalar
tensor(1)

torch.*like以及torch.new_*

第一个是可以创建, shape相同, 数据类型相同.

>>> x = torch.randn(3, dtype=torch.float64)
 >>> torch.zeros_like(x)
 tensor([ 0., 0., 0.], dtype=torch.float64)
 >>> torch.zeros_like(x, dtype=torch.int)
 tensor([ 0, 0, 0], dtype=torch.int32)

当然如果是单纯想要得到属性与前者相同的Tensor, 但是shape不想要一致:

>>> x = torch.randn(3, dtype=torch.float64)
 >>> x.new_ones(2) # 属性一致
 tensor([ 1., 1.], dtype=torch.float64)
 >>> x.new_ones(4, dtype=torch.int)
 tensor([ 1, 1, 1, 1], dtype=torch.int32)

书写 device-agnostic 的代码

这个含义是, 不要显示的指定是gpu, cpu之类的. 利用.to()来执行.

# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)

迁移代码对比

以前的写法

model = MyRNN()
 if use_cuda:
   model = model.cuda()

 # train
 total_loss = 0
 for input, target in train_loader:
   input, target = Variable(input), Variable(target)
   hidden = Variable(torch.zeros(*h_shape)) # init hidden
   if use_cuda:
     input, target, hidden = input.cuda(), target.cuda(), hidden.cuda()
   ... # get loss and optimize
   total_loss += loss.data[0]

 # evaluate
 for input, target in test_loader:
   input = Variable(input, volatile=True)
   if use_cuda:
     ...
   ...

现在的写法

# torch.device object used throughout this script
 device = torch.device("cuda" if use_cuda else "cpu")

 model = MyRNN().to(device)

 # train
 total_loss = 0
 for input, target in train_loader:
   input, target = input.to(device), target.to(device)
   hidden = input.new_zeros(*h_shape) # has the same device & dtype as `input`
   ... # get loss and optimize
   total_loss += loss.item()      # get Python number from 1-element Tensor

 # evaluate
 with torch.no_grad():          # operations inside don't track history
   for input, target in test_loader:
     ...

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python之配置日志的几种方式
May 22 Python
python 垃圾收集机制的实例详解
Aug 20 Python
使用Python读取大文件的方法
Feb 11 Python
Python中的pack和unpack的使用
Mar 12 Python
利用python打开摄像头及颜色检测方法
Aug 03 Python
python hbase读取数据发送kafka的方法
Dec 27 Python
十个Python练手的实战项目,学会这些Python就基本没问题了(推荐)
Apr 26 Python
Python基础学习之基本数据结构详解【数字、字符串、列表、元组、集合、字典】
Jun 18 Python
pyqt5让图片自适应QLabel大小上以及移除已显示的图片方法
Jun 21 Python
Django框架教程之中间件MiddleWare浅析
Dec 29 Python
Python telnet登陆功能实现代码
Apr 16 Python
python numpy中setdiff1d的用法说明
Apr 22 Python
对pyqt5多线程正确的开启姿势详解
Jun 14 #Python
Python+PyQT5的子线程更新UI界面的实例
Jun 14 #Python
在PYQT5中QscrollArea(滚动条)的使用方法
Jun 14 #Python
PYQT5设置textEdit自动滚屏的方法
Jun 14 #Python
使用PyQt4 设置TextEdit背景的方法
Jun 14 #Python
Ubuntu18.04中Python2.7与Python3.6环境切换
Jun 14 #Python
ubuntu 16.04下python版本切换的方法
Jun 14 #Python
You might like
Yii框架连接mongodb数据库的代码
2016/07/27 PHP
TP5框架model常见操作示例小结【增删改查、聚合、时间戳、软删除等】
2020/04/05 PHP
用javascript获得地址栏参数的两种方法
2006/11/08 Javascript
精解window.setTimeout()&amp;window.setInterval()使用方式与参数传递问题!
2007/11/23 Javascript
番茄的表单验证类代码修改版
2008/07/18 Javascript
jQuery AJAX 调用WebService实现代码
2010/03/24 Javascript
JAVASCRIPT车架号识别/验证函数代码 汽车车架号验证程序
2012/01/08 Javascript
javascript中的取反再取反~~没有意义
2014/04/06 Javascript
JS判断是否360安全浏览器极速内核的方法
2015/01/29 Javascript
JavaScript阻止浏览器返回按钮的方法
2015/03/18 Javascript
jQuery的Scrollify插件实现滑动到页面下一节点
2015/07/05 Javascript
5种JavaScript脚本加载的方式
2017/01/16 Javascript
Angular 4.x中表单Reactive Forms详解
2017/04/25 Javascript
Vue自定义指令封装节流函数的方法示例
2018/07/09 Javascript
js中的闭包实例展示
2018/11/01 Javascript
基于mpvue小程序使用echarts画折线图的方法示例
2019/04/24 Javascript
swiper自定义分页器的样式
2020/09/14 Javascript
python smtplib模块自动收发邮件功能(一)
2018/05/22 Python
可能是最全面的 Python 字符串拼接总结【收藏】
2018/07/09 Python
Python中join()函数多种操作代码实例
2020/01/13 Python
使用python批量修改XML文件中图像的depth值
2020/07/22 Python
详解Python的爬虫框架 Scrapy
2020/08/03 Python
python实现图片转换成素描和漫画格式
2020/08/19 Python
python中添加模块导入路径的方法
2021/02/03 Python
涂鸦板简单实现 Html5编写属于自己的画画板
2016/07/05 HTML / CSS
诗普兰迪官方网站:Splendid
2018/09/18 全球购物
英国领先的豪华时尚家居网上商店:Amara
2019/08/12 全球购物
西部世纪.net笔试题面试题
2014/04/03 面试题
介绍一下.NET构架下remoting和webservice
2014/05/08 面试题
药学专业个人自我评价
2013/11/11 职场文书
国旗下的演讲稿
2014/05/08 职场文书
小学公民道德宣传日活动总结
2015/03/23 职场文书
小学总务工作总结
2015/08/13 职场文书
Go 语言中 20 个占位符的整理
2021/10/16 Golang
python机器学习Github已达8.9Kstars模型解释器LIME
2021/11/23 Python
docker compose 部署 golang 的 Athens 私有代理问题
2022/04/28 Servers