Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
解决Python print输出不换行没空格的问题
Nov 14 Python
python 实现调用子文件下的模块方法
Dec 07 Python
Python实现的旋转数组功能算法示例
Feb 23 Python
pandas 数据结构之Series的使用方法
Jun 21 Python
python KNN算法实现鸢尾花数据集分类
Oct 24 Python
python 协程中的迭代器,生成器原理及应用实例详解
Oct 28 Python
python 求10个数的平均数实例
Dec 16 Python
Python如何使用turtle库绘制图形
Feb 26 Python
python模拟斗地主发牌
Apr 22 Python
Pytorch转keras的有效方法,以FlowNet为例讲解
May 26 Python
windows安装python超详细图文教程
May 21 Python
python爬虫之selenium库的安装及使用教程
May 23 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
PHP安装问题
2006/10/09 PHP
php知道与问问的采集插件代码
2010/10/12 PHP
Ajax+PHP快速上手及简单应用说明
2013/07/24 PHP
详解PHP中array_rand函数的使用方法
2016/09/11 PHP
PHP封装cURL工具类与应用示例
2019/07/01 PHP
javascript 建设银行登陆键盘
2008/06/10 Javascript
imgAreaSelect 中文文档帮助说明
2011/10/08 Javascript
在标题栏显示新消息提示,很多公司项目中用到这个方法
2011/11/04 Javascript
Javascript计算两个marker之间的距离(Google Map V3)
2013/04/26 Javascript
document.forms[].submit()使用介绍
2014/02/19 Javascript
js改变鼠标的形状和样式的方法
2014/03/31 Javascript
jQuery filter函数使用方法
2014/05/19 Javascript
js实现每日自动换一张图片的方法
2015/05/04 Javascript
jquery事件的ready()方法使用详解
2015/11/11 Javascript
超精准的javascript验证身份证号的具体实现方法
2015/11/18 Javascript
JavaScript统计网站访问次数的实现代码
2015/11/18 Javascript
老生常谈javascript的类型转换
2016/10/12 Javascript
javascript设计模式之单体模式学习笔记
2017/02/15 Javascript
基于js中style.width与offsetWidth的区别(详解)
2017/11/12 Javascript
在vue中实现echarts随窗体变化
2020/07/27 Javascript
js实现微信聊天界面
2020/08/09 Javascript
简单介绍Python的Django框架加载模版的方式
2015/07/20 Python
Python模块WSGI使用详解
2018/02/02 Python
Django实战之用户认证(初始配置)
2018/07/16 Python
pandas 数据归一化以及行删除例程的方法
2018/11/10 Python
python虚拟环境的安装和配置(virtualenv,virtualenvwrapper)
2019/08/09 Python
Python Tornado之跨域请求与Options请求方式
2020/03/28 Python
在pycharm中debug 实时查看数据操作(交互式)
2020/06/09 Python
英国儿童图书网站:Scholastic
2017/03/26 全球购物
全球最大最受欢迎的旅游社区:Tripadvisor
2017/11/03 全球购物
英国马莎百货印度官网:Marks & Spencer印度
2020/10/08 全球购物
丝芙兰墨西哥官网:Sephora墨西哥
2020/05/30 全球购物
班组长安全生产职责
2013/12/16 职场文书
党校培训自我鉴定范文
2014/03/20 职场文书
幼儿园毕业致辞
2015/07/29 职场文书
php字符串倒叙
2021/04/01 PHP