Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
浅析Python中signal包的使用
Nov 13 Python
python实现淘宝秒杀聚划算抢购自动提醒源码
Jun 23 Python
在Python中分别打印列表中的每一个元素方法
Nov 07 Python
使用python itchat包爬取微信好友头像形成矩形头像集的方法
Feb 21 Python
对Django项目中的ORM映射与模糊查询的使用详解
Jul 18 Python
使用Python代码实现Linux中的ls遍历目录命令的实例代码
Sep 07 Python
Pycharm+Python+PyQt5使用详解
Sep 25 Python
Window10下python3.7 安装与卸载教程图解
Sep 30 Python
Python如何批量获取文件夹的大小并保存
Mar 31 Python
Python读取配置文件(config.ini)以及写入配置文件
Apr 08 Python
查找适用于matplotlib的中文字体名称与实际文件名对应关系的方法
Jan 05 Python
Python 中如何使用 virtualenv 管理虚拟环境
Jan 21 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
PHP6 先修班 JSON实例代码
2008/08/23 PHP
网友原创的PHP模板类代码
2008/09/07 PHP
PHP截断标题且兼容utf8和gb2312编码
2013/09/22 PHP
PHP常用字符串操作函数实例总结(trim、nl2br、addcslashes、uudecode、md5等)
2016/01/09 PHP
基于PHP常用文件函数和目录函数整理
2017/08/17 PHP
基于jquery的地址栏射击游戏代码
2011/03/10 Javascript
Jquery 的扩展方法总结
2011/10/01 Javascript
javascript实现简单的Map示例介绍
2013/12/23 Javascript
javascript+HTML5 Canvas绘制转盘抽奖
2020/05/16 Javascript
Google 地图事件实例讲解
2016/08/06 Javascript
js中常用的Tab切换效果(推荐)
2016/08/30 Javascript
jquery实现简单的瀑布流布局
2016/12/11 Javascript
js实现图片上传预览原理分析
2017/07/13 Javascript
JS沙箱模式实例分析
2017/09/04 Javascript
快速将Vue项目升级到webpack3的方法步骤
2017/09/14 Javascript
详解javascript中的babel到底是什么
2018/06/21 Javascript
vue项目或网页上实现文字转换成语音播放功能
2020/06/09 Javascript
[02:16]DOTA2超级联赛专访Burning 逆袭需要抓住机会
2013/06/24 DOTA
[01:02:07]Liquid vs Newbee 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
python使用urllib模块和pyquery实现阿里巴巴排名查询
2014/01/16 Python
使用70行Python代码实现一个递归下降解析器的教程
2015/04/17 Python
浅谈python字符串方法的简单使用
2016/07/18 Python
python 遍历字符串(含汉字)实例详解
2017/04/04 Python
Python编程pygame模块实现移动的小车示例代码
2018/01/03 Python
python 读文件,然后转化为矩阵的实例
2018/04/23 Python
基于Python List的赋值方法
2018/06/23 Python
利用Python检测URL状态
2019/07/31 Python
对Django中的权限和分组管理实例讲解
2019/08/16 Python
Python判断三段线能否构成三角形的代码
2020/04/12 Python
python中pandas库中DataFrame对行和列的操作使用方法示例
2020/06/14 Python
python实现文件+参数发送request的实例代码
2021/01/05 Python
日本PLST在线商店:日本时尚杂志刊载的人气服装
2016/12/10 全球购物
巴西在线鞋店:Shoestock
2017/10/28 全球购物
Sarenza德国:法国最大的时尚鞋和包包网上商店
2019/06/08 全球购物
护士自我鉴定
2013/10/23 职场文书
Python打包exe时各种异常处理方案总结
2021/05/18 Python