Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python字符串加密解密的三种方法分享(base64 win32com)
Jan 19 Python
python进阶教程之循环相关函数range、enumerate、zip
Aug 30 Python
Python中实现两个字典(dict)合并的方法
Sep 23 Python
python端口扫描系统实现方法
Nov 19 Python
Fabric 应用案例
Aug 28 Python
Python爬虫包 BeautifulSoup  递归抓取实例详解
Jan 28 Python
Python爬虫获取整个站点中的所有外部链接代码示例
Dec 26 Python
python3解析库lxml的安装与基本使用
Jun 27 Python
Python 中的 import 机制之实现远程导入模块
Oct 29 Python
python写一个随机点名软件的实例
Nov 28 Python
python程序需要编译吗
Jun 19 Python
python 爬虫请求模块requests详解
Dec 04 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
php FLEA中二叉树数组的遍历输出
2012/09/26 PHP
PHP读取大文件的类SplFileObject使用介绍
2014/04/09 PHP
Zend Framework入门教程之Zend_Registry组件用法详解
2016/12/09 PHP
PHP判断一个数组是另一个数组子集的方法详解
2017/07/31 PHP
PHP生成二维码与识别二维码的方法详解【附源码下载】
2019/03/07 PHP
使用jQuery全局事件ajaxStart为特定请求实现提示效果的代码
2010/12/30 Javascript
$.ajax返回的JSON无法执行success的解决方法
2011/09/09 Javascript
『jQuery』取指定url格式及分割函数应用
2013/04/22 Javascript
文档对象模型DOM通俗讲解
2013/11/01 Javascript
javascript垃圾收集机制与内存泄漏详细解析
2013/11/11 Javascript
js几秒以后倒计时跳转示例
2013/12/26 Javascript
关于javaScript注册click事件传递参数的不成功问题
2014/07/18 Javascript
JavaScript字符串对象substr方法入门实例(用于截取字符串)
2014/10/16 Javascript
jQuery中detach()方法用法实例
2014/12/25 Javascript
jQuery中toggle()函数的使用实例
2015/04/17 Javascript
Jquery ajax 同步阻塞引起的UI线程阻塞问题
2015/11/17 Javascript
拥Bootstrap入怀——导航栏篇
2016/05/30 Javascript
两种简单的跨域方法(jsonp、php)
2017/01/02 Javascript
js实现随机数字字母验证码
2017/06/19 Javascript
js实现图片轮播效果学习笔记
2017/07/26 Javascript
React中jquery引用的实现方法
2017/09/12 jQuery
nodejs操作mongodb的增删改查功能实例
2017/11/09 NodeJs
vue.js实现插入数值与表达式的方法分析
2018/07/06 Javascript
node.js ws模块搭建websocket服务端的方法示例
2019/04/25 Javascript
JS实现购物车基本功能
2020/11/08 Javascript
vue中defineProperty和Proxy的区别详解
2020/11/30 Vue.js
python实现从web抓取文档的方法
2014/09/26 Python
python将txt文件读入为np.array的方法
2018/10/30 Python
python实现五子棋人机对战游戏
2020/03/25 Python
如何通过50行Python代码获取公众号全部文章
2019/07/12 Python
python torch.utils.data.DataLoader使用方法
2020/04/02 Python
CSS3中的clip-path使用攻略
2015/08/03 HTML / CSS
Philosophy美国官网:美国美容品牌
2016/08/15 全球购物
英文版餐饮业求职信
2013/10/18 职场文书
2014年幼儿园后勤工作总结
2014/11/10 职场文书
监护人证明
2015/06/19 职场文书