Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
利用Python的装饰器解决Bottle框架中用户验证问题
Apr 24 Python
python 数据的清理行为实例详解
Jul 12 Python
分享几道你可能遇到的python面试题
Jul 24 Python
Python实现霍夫圆和椭圆变换代码详解
Jan 12 Python
python pandas.DataFrame选取、修改数据最好用.loc,.iloc,.ix实现
Jun 11 Python
Windows下Anaconda2安装NLTK教程
Sep 19 Python
Python通过Tesseract库实现文字识别
Mar 05 Python
python实现TCP文件传输
Mar 20 Python
python 给图像添加透明度(alpha通道)
Apr 09 Python
Python HTMLTestRunner如何下载生成报告
Sep 04 Python
通俗易懂了解Python装饰器原理
Sep 17 Python
Python之多进程与多线程的使用
Feb 23 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
php调用nginx的mod_zip模块打包ZIP文件
2014/06/11 PHP
ThinkPHP 3.2.2实现事务操作的方法
2017/05/05 PHP
PHP7原生MySQL数据库操作实现代码
2020/07/03 PHP
一个无限级XML绑定跨框架菜单(For IE)
2007/01/27 Javascript
网上抓的一个特效
2007/05/11 Javascript
jQuery开发者都需要知道的5个小技巧
2010/01/08 Javascript
JavaScript访问样式表代码
2010/10/15 Javascript
javascript学习笔记(五)正则表达式
2011/04/08 Javascript
jquery checkbox实现单选小例
2013/11/27 Javascript
浅谈javascript中的Function和Arguments
2016/08/30 Javascript
js导出excel文件的简洁方法(推荐)
2016/11/02 Javascript
js图片放大镜效果实现方法详解
2020/10/28 Javascript
JS switch判断 三目运算 while 及 属性操作代码
2017/09/03 Javascript
JS高阶函数原理与用法实例分析
2019/01/15 Javascript
《javascript设计模式》学习笔记三:Javascript面向对象程序设计单例模式原理与实现方法分析
2020/04/07 Javascript
vue中移动端调取本地的复制的文本方式
2020/07/18 Javascript
Python使用稀疏矩阵节省内存实例
2014/06/27 Python
itchat和matplotlib的结合使用爬取微信信息的实例
2017/08/25 Python
PyQt5打开文件对话框QFileDialog实例代码
2018/02/07 Python
python将四元数变换为旋转矩阵的实例
2019/12/04 Python
基于Tensorflow高阶读写教程
2020/02/10 Python
matlab中imadjust函数的作用及应用举例
2020/02/27 Python
python 基于opencv实现高斯平滑
2020/12/18 Python
使用HTML5 Canvas API中的clip()方法裁剪区域图像
2016/03/25 HTML / CSS
canvas实现圆绘制的示例代码
2019/09/11 HTML / CSS
HTML5教程之html 5 本地数据库(Web Sql Database)
2014/04/03 HTML / CSS
白宫黑市官网:White House Black Market
2016/11/17 全球购物
最好的意大利皮夹克:D’Arienzo
2018/12/04 全球购物
印度民族服装购物网站:BIBA
2019/08/05 全球购物
Shell如何接收变量输入
2012/09/24 面试题
外贸公司实习自我鉴定
2013/09/24 职场文书
爱国卫生月实施方案
2014/02/21 职场文书
分公司总经理岗位职责
2014/07/30 职场文书
2014年置业顾问工作总结
2014/11/17 职场文书
2014年学校总务处工作总结
2014/12/08 职场文书
Python 解决空列表.append() 输出为None的问题
2021/05/23 Python