Pytorch之Variable的用法


Posted in Python onDecember 31, 2019

1.简介

torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

Variable和tensor的区别和联系

Variable是篮子,而tensor是鸡蛋,鸡蛋应该放在篮子里才能方便拿走(定义variable时一个参数就是tensor)

Variable这个篮子里除了装了tensor外还有requires_grad参数,表示是否需要对其求导,默认为False

Variable这个篮子呢,自身有一些属性

比如grad,梯度variable.grad是d(y)/d(variable)保存的是变量y对variable变量的梯度值,如果requires_grad参数为False,所以variable.grad返回值为None,如果为True,返回值就为对variable的梯度值

比如grad_fn,对于用户自己创建的变量(Variable())grad_fn是为none的,也就是不能调用backward函数,但对于由计算生成的变量,如果存在一个生成中间变量的requires_grad为true,那其的grad_fn不为none,反则为none

比如data,这个就很简单,这个属性就是装的鸡蛋(tensor)

Varibale包含三个属性:

data:存储了Tensor,是本体的数据 grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致 grad_fn:指向Function对象,用于反向传播的梯度计算之用

代码1

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
y = x + temp + 2
y = y.mean() #求平均数
 
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

输出1

none

(因为requires_grad=False)

代码2

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
 
y = x + temp + 2
y = y.mean() #求平均数
 
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(temp.grad) # d(y)/d(temp)

输出2

tensor([[0.2500, 0.2500],
[0.2500, 0.2500]])

代码3

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
 
y = x + 2
y = y.mean() #求平均数
 
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)

输出3

Traceback (most recent call last):
File "path", line 12, in <module>
y.backward()

(报错了,因为生成变量y的中间变量只有x,而x的requires_grad是False,所以y的grad_fn是none)

代码4

import numpy as np
import torch
from torch.autograd import Variable
 
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
 
 
y = x + 2
y = y.mean() #求平均数
 
#y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(y.grad_fn) # d(y)/d(x)

输出4

none

2.grad属性

在每次backward后,grad值是会累加的,所以利用BP算法,每次迭代是需要将grad清零的。

x.grad.data.zero_()

(in-place操作需要加上_,即zero_)

以上这篇Pytorch之Variable的用法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python字符串加密解密的三种方法分享(base64 win32com)
Jan 19 Python
Python实现提取谷歌音乐搜索结果的方法
Jul 10 Python
python Django批量导入数据
Mar 25 Python
使用Python读写及压缩和解压缩文件的示例
Jul 08 Python
Python 专题二 条件语句和循环语句的基础知识
Mar 19 Python
python使用fork实现守护进程的方法
Nov 16 Python
使用Python搭建虚拟环境的配置方法
Feb 28 Python
Jupyter安装nbextensions,启动提示没有nbextensions库
Apr 23 Python
10分钟教你用Python实现微信自动回复功能
Nov 28 Python
python使用knn实现特征向量分类
Dec 26 Python
Python微医挂号网医生数据抓取
Jan 24 Python
使用Python快乐学数学Github万星神器Manim简介
Aug 07 Python
Pytorch 多块GPU的使用详解
Dec 31 #Python
Pyorch之numpy与torch之间相互转换方式
Dec 31 #Python
pytorch sampler对数据进行采样的实现
Dec 31 #Python
关于pytorch处理类别不平衡的问题
Dec 31 #Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
You might like
自制短波长线天线频率预选器 - 成功消除B2K之流的镜像
2021/03/02 无线电
php smarty模版引擎中的缓存应用
2009/12/02 PHP
php安全开发 添加随机字符串验证,防止伪造跨站请求
2013/02/14 PHP
php实现Linux服务器木马排查及加固功能
2014/12/29 PHP
调试WordPress中定时任务的相关PHP脚本示例
2015/12/10 PHP
深入理解php printf() 输出格式化的字符串
2016/05/23 PHP
PHP使用openssl扩展实现加解密方法示例
2020/02/20 PHP
收藏Javascript中常用的55个经典技巧
2007/08/12 Javascript
将json当数据库一样操作的javascript lib
2013/10/28 Javascript
JavaScript获取当前页面上的指定对象示例代码
2014/02/28 Javascript
Bootstrap3 Grid system原理及应用详解
2016/09/30 Javascript
jQuery autoComplete插件两种使用方式及动态改变参数值的方法详解
2016/10/24 Javascript
简单实现js菜单栏切换效果
2017/03/04 Javascript
vue v-on监听事件详解
2017/05/17 Javascript
jquery Form轻松实现文件上传
2017/05/24 jQuery
Angular js 实现添加用户、修改密码、敏感字、下拉菜单的综合操作方法
2017/10/24 Javascript
vue实现固定位置显示功能
2019/05/30 Javascript
独立部署小程序基于nodejs的服务器过程详解
2019/06/24 NodeJs
vue中h5端打开app(判断是安卓还是苹果)
2021/02/26 Vue.js
Python Socket传输文件示例
2017/01/16 Python
Python实现二维数组按照某行或列排序的方法【numpy lexsort】
2017/09/22 Python
Python实现控制台中的进度条功能代码
2017/12/22 Python
Python遍历pandas数据方法总结
2018/02/09 Python
python实现对文件中图片生成带标签的txt文件方法
2018/04/27 Python
django1.11.1 models 数据库同步方法
2018/05/30 Python
python将txt等文件中的数据读为numpy数组的方法
2018/12/22 Python
Python3多线程版TCP端口扫描器
2019/08/31 Python
Python pip 安装与使用(安装、更新、删除)
2019/10/06 Python
Pytorch maxpool的ceil_mode用法
2020/02/18 Python
Python如何重新加载模块
2020/07/29 Python
Canvas 像素处理之改变透明度的实现代码
2019/01/08 HTML / CSS
如何使用amaze ui的分页样式封装一个通用的JS分页控件
2020/08/21 HTML / CSS
网络、C以及其他硬件方面的面试题
2016/08/23 面试题
保密工作实施方案
2014/02/24 职场文书
公证处委托书
2015/01/28 职场文书
四十年同学聚会致辞
2015/07/28 职场文书