Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
在Python中使用PIL模块对图片进行高斯模糊处理的教程
May 05 Python
python遍历 truple list dictionary的几种方法总结
Sep 11 Python
python numpy函数中的linspace创建等差数列详解
Oct 13 Python
微信跳一跳自动运行python脚本
Jan 08 Python
在Django中输出matplotlib生成的图片方法
May 24 Python
Python实现的个人所得税计算器示例
Jun 01 Python
使用pandas将numpy中的数组数据保存到csv文件的方法
Jun 14 Python
Python使用post及get方式提交数据的实例
Jan 24 Python
Python中的集合介绍
Jan 28 Python
kafka监控获取指定topic的消息总量示例
Dec 23 Python
Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式
Jan 10 Python
python 管理系统实现mysql交互的示例代码
Dec 06 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
检查用户名是否已在mysql中存在的php写法
2014/01/20 PHP
PhpDocumentor 2安装以及生成API文档的方法
2014/05/21 PHP
PHP统计目录大小的自定义函数分享
2014/11/18 PHP
yii框架无限极分类的实现方法
2017/04/08 PHP
PHP类的自动加载机制实现方法分析
2019/01/10 PHP
浅谈Laravel POST,PUT,PATCH 路由的区别
2019/10/15 PHP
一段多浏览器的"复制到剪贴板"javascript代码
2007/03/27 Javascript
js的表单操作 简单计算器
2011/12/29 Javascript
Extjs4 Treegrid 使用心得分享(经验篇)
2013/07/01 Javascript
JS简单实现元素复制示例附图
2013/11/19 Javascript
网页中表单按回车就自动提交的问题的解决方案
2014/11/03 Javascript
javascript使用avalon绑定实现checkbox全选
2015/05/06 Javascript
jquery实现加载进度条提示效果
2015/11/23 Javascript
浅谈几种常用的JS类定义方法
2016/06/08 Javascript
jquery radio的取值_radio的选中_radio的重置方法
2016/09/20 Javascript
jQuery Ajax请求后台数据并在前台接收
2016/12/10 Javascript
微信小程序 视图容器组件的详解及实例代码
2017/01/19 Javascript
AngularJS实现动态添加Option的方法
2017/05/17 Javascript
JS实现延迟隐藏功能的方法(类似QQ头像鼠标放上展示信息)
2017/12/28 Javascript
layui内置模块layim发送图片添加加载动画的方法
2019/09/23 Javascript
使用Typescript和ES模块发布Node模块的方法
2020/05/25 Javascript
[02:10]探秘浦东源深体育馆 DOTA2 Supermajor不见不散
2018/05/17 DOTA
Python中使用摄像头实现简单的延时摄影技术
2015/03/27 Python
Python新手入门最容易犯的错误总结
2017/04/24 Python
微信跳一跳python自动代码解读1.0
2018/01/12 Python
使用Python自动化破解自定义字体混淆信息的方法实例
2019/02/13 Python
详解Python读取yaml文件多层菜单
2019/03/23 Python
python cv2读取rtsp实时码流按时生成连续视频文件方式
2019/12/25 Python
奥兰多迪士尼门票折扣:Undercover Tourist
2018/07/09 全球购物
通信工程专业个人找工作求职信范文
2013/09/21 职场文书
生物科学专业个人求职信范文
2013/12/05 职场文书
写给女朋友的道歉信
2014/01/12 职场文书
工地宣传标语
2014/06/18 职场文书
python基础之//、/与%的区别详解
2022/06/10 Python
CSS 鼠标点击拖拽效果的实现代码
2022/12/24 HTML / CSS
详解MySQL的内连接和外连接
2023/05/08 MySQL