Pytorch之Variable的用法


Posted in Python onDecember 31, 2019

1.简介

torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

Variable和tensor的区别和联系

Variable是篮子,而tensor是鸡蛋,鸡蛋应该放在篮子里才能方便拿走(定义variable时一个参数就是tensor)

Variable这个篮子里除了装了tensor外还有requires_grad参数,表示是否需要对其求导,默认为False

Variable这个篮子呢,自身有一些属性

比如grad,梯度variable.grad是d(y)/d(variable)保存的是变量y对variable变量的梯度值,如果requires_grad参数为False,所以variable.grad返回值为None,如果为True,返回值就为对variable的梯度值

比如grad_fn,对于用户自己创建的变量(Variable())grad_fn是为none的,也就是不能调用backward函数,但对于由计算生成的变量,如果存在一个生成中间变量的requires_grad为true,那其的grad_fn不为none,反则为none

比如data,这个就很简单,这个属性就是装的鸡蛋(tensor)

Varibale包含三个属性:

data:存储了Tensor,是本体的数据 grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致 grad_fn:指向Function对象,用于反向传播的梯度计算之用

代码1

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
y = x + temp + 2
y = y.mean() #求平均数
 
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

输出1

none

(因为requires_grad=False)

代码2

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
 
y = x + temp + 2
y = y.mean() #求平均数
 
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(temp.grad) # d(y)/d(temp)

输出2

tensor([[0.2500, 0.2500],
[0.2500, 0.2500]])

代码3

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
 
y = x + 2
y = y.mean() #求平均数
 
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

输出3

Traceback (most recent call last):
File "path", line 12, in <module>
y.backward()

(报错了,因为生成变量y的中间变量只有x,而x的requires_grad是False,所以y的grad_fn是none)

代码4

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
 
y = x + 2
y = y.mean() #求平均数
 
#y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(y.grad_fn) # d(y)/d(x)

输出4

none

2.grad属性

在每次backward后,grad值是会累加的,所以利用BP算法,每次迭代是需要将grad清零的。

x.grad.data.zero_()

(in-place操作需要加上_,即zero_)

以上这篇Pytorch之Variable的用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现博客文章爬虫示例
Feb 26 Python
Python2.6版本中实现字典推导 PEP 274(Dict Comprehensions)
Apr 28 Python
Python实现的简单线性回归算法实例分析
Dec 26 Python
Python3.5实现的罗马数字转换成整数功能示例
Feb 25 Python
详解python爬虫系列之初识爬虫
Apr 06 Python
在Pycharm中使用GitHub的方法步骤
Jun 13 Python
Python创建或生成列表的操作方法
Jun 19 Python
python opencv 简单阈值算法的实现
Aug 04 Python
Django 源码WSGI剖析过程详解
Aug 05 Python
python代码实现TSNE降维数据可视化教程
Feb 28 Python
python线程优先级队列知识点总结
Feb 28 Python
使用Django框架创建项目
Jun 10 Python
Pytorch 多块GPU的使用详解
Dec 31 #Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
You might like
PHP中使用循环实现的金字塔图形
2014/11/08 PHP
yii2 RBAC使用DbManager实现后台权限判断的方法
2016/07/23 PHP
PHP中命名空间的使用例子
2019/03/22 PHP
jQuery 渐变下拉菜单
2009/12/15 Javascript
JSQL  一个 web DB 的封装
2010/05/05 Javascript
js屏蔽鼠标键盘(右键/Ctrl+N/Shift+F10/F11/F5刷新/退格键)
2013/01/24 Javascript
让jQuery与其他JavaScript库并存避免冲突的方法
2013/12/23 Javascript
jQuery基于ID调用指定iframe页面内的方法
2016/07/06 Javascript
angularJS Provider、factory、service详解及实例代码
2016/09/21 Javascript
js前端解决跨域问题的8种方案(最新最全)
2016/11/18 Javascript
ES6中参数的默认值语法介绍
2017/05/03 Javascript
浅谈express 中间件机制及实现原理
2017/08/31 Javascript
Js利用console计算代码运行时间的方法示例
2017/09/24 Javascript
基于VUE.JS的移动端框架Mint UI的使用
2017/10/11 Javascript
JS+php后台实现文件上传功能详解
2019/03/02 Javascript
Angular 中使用 FineReport不显示报表直接打印预览
2019/08/21 Javascript
vue中的双向数据绑定原理与常见操作技巧详解
2020/03/16 Javascript
详解Django中的ifequal和ifnotequal标签使用
2015/07/16 Python
django轻松使用富文本编辑器CKEditor的方法
2017/03/30 Python
python基础教程项目四之新闻聚合
2018/04/02 Python
Python面向对象程序设计类的多态用法详解
2019/04/12 Python
Django框架实现分页显示内容的方法详解
2019/05/10 Python
Python3 pandas 操作列表实例详解
2019/09/23 Python
关于numpy数组轴的使用详解
2019/12/05 Python
linux环境下安装python虚拟环境及注意事项
2020/01/07 Python
python中lower函数实现方法及用法讲解
2020/12/23 Python
几道数据库的面试题或笔试题
2014/05/31 面试题
解释DataSet(ds) 和 ds as DataSet 的含义
2014/07/27 面试题
几道Java和数据库的面试题
2013/05/30 面试题
通信工程专业毕业生推荐信
2013/12/25 职场文书
中层竞聘演讲稿
2014/01/09 职场文书
反对四风自我剖析材料
2014/10/07 职场文书
党员批评与自我批评思想汇报
2014/10/08 职场文书
2015年扫黄打非工作总结
2015/05/13 职场文书
中秋节主题班会
2015/08/14 职场文书
青年干部培训班学习心得体会
2016/01/06 职场文书