Pytorch 使用tensor特定条件判断索引


Posted in Python onApril 08, 2021

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

Pytorch 使用tensor特定条件判断索引

想要实现numpy中where()的功能,可以借助nonzero()

Pytorch 使用tensor特定条件判断索引

对应numpy中的where()操作效果:

Pytorch 使用tensor特定条件判断索引

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

Pytorch 使用tensor特定条件判断索引

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

Pytorch 使用tensor特定条件判断索引

import torch as t
a = t.ones(10,)
b = a.detach()
print(b)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)
a = detach()
b = B(a)
loss = criterion(b, target)
loss.backward()

来看一个实际的例子:

import torch as t
x = t.ones(1, requires_grad=True)
x.requires_grad   #True
y = t.ones(1, requires_grad=True)
y.requires_grad   #True
x = x.detach()   #分离之后
x.requires_grad   #False
y = x+y         #tensor([2.])
y.requires_grad   #我还是True
y.retain_grad()   #y不是叶子张量,要加上这一行
z = t.pow(y, 2)
z.backward()    #反向传播
y.grad        #tensor([4.])
x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

?假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters():
 param.requires_grad = False
a = A(input)
b = B(a)
loss = criterion(b, target)
loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python网络编程学习笔记(六):Web客户端访问
Jun 09 Python
python设计模式大全
Jun 27 Python
老生常谈python之鸭子类和多态
Jun 13 Python
python实现推箱子游戏
Mar 25 Python
浅谈python实现Google翻译PDF,解决换行的问题
Nov 28 Python
Python切片操作去除字符串首尾的空格
Apr 22 Python
Python深拷贝与浅拷贝用法实例分析
May 05 Python
django admin后台添加导出excel功能示例代码
May 15 Python
opencv 获取rtsp流媒体视频的实现方法
Aug 23 Python
浅谈Pytorch torch.optim优化器个性化的使用
Feb 20 Python
python中setuptools的作用是什么
Jun 19 Python
MAC平台基于Python Appium环境搭建过程图解
Aug 13 Python
selenium.webdriver中add_argument方法常用参数表
Apr 08 #Python
python3使用diagrams绘制架构图的步骤
python实现求纯色彩图像的边框
python爬取企查查企业信息之selenium自动模拟登录企查查
Python3 使用pip安装git并获取Yahoo金融数据的操作
Apr 08 #Python
Django 如何实现文件上传下载
Apr 08 #Python
python3 删除所有自定义变量的操作
Apr 08 #Python
You might like
DISCUZ 论坛管理员密码忘记的解决方法
2009/05/14 PHP
第4章 数据处理-php数组的处理-郑阿奇
2011/07/04 PHP
php中使用PHPExcel读写excel(xls)文件的方法
2014/09/15 PHP
PHP中Header使用的HTTP协议及常用方法小结
2014/11/04 PHP
php+mysql结合Ajax实现点赞功能完整实例
2015/01/30 PHP
分享一则PHP定义函数代码
2015/02/26 PHP
Twig模板引擎用法入门教程
2016/01/20 PHP
laravel创建类似ThinPHP中functions.php的全局函数
2016/11/26 PHP
PHP异常类及异常处理操作实例详解
2018/12/19 PHP
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
模仿JQuery.extend函数扩展自己对象的js代码
2009/12/09 Javascript
javascript与CSS复习(二)
2010/06/29 Javascript
js截取函数(indexOf,join等)
2010/09/01 Javascript
在Python中使用glob模块查找文件路径的方法
2015/06/17 Javascript
JS实现漂亮的时间选择框效果
2016/08/20 Javascript
JavaScript排序算法动画演示效果的实现方法
2016/10/18 Javascript
详解基于 Nuxt 的 Vue.js 服务端渲染实践
2017/10/24 Javascript
PHP自动加载autoload和命名空间的应用小结
2017/12/01 Javascript
Vue.js+Layer表格数据绑定与实现更新的实例
2018/03/07 Javascript
javascript中call,apply,callee,caller用法实例分析
2019/07/24 Javascript
详解钉钉小程序组件之自定义模态框(弹窗封装实现)
2020/03/07 Javascript
如何将python中的List转化成dictionary
2016/08/15 Python
Python实现选择排序
2017/06/04 Python
使用python实现个性化词云的方法
2017/06/16 Python
Python3数字求和的实例
2019/02/19 Python
Python爬虫beautifulsoup4常用的解析方法总结
2019/02/25 Python
tensorflow 报错unitialized value的解决方法
2020/02/06 Python
Pytorch maxpool的ceil_mode用法
2020/02/18 Python
python 图像插值 最近邻、双线性、双三次实例
2020/07/05 Python
PIP和conda 更换国内安装源的方法步骤
2020/09/21 Python
小学教师国培感言
2014/02/08 职场文书
旷工检讨书1000字
2015/01/01 职场文书
社区国庆节活动总结
2015/03/23 职场文书
政协工作总结2015
2015/05/20 职场文书
mybatis调用sqlserver存储过程返回结果集的方法
2021/05/08 SQL Server
详解MySQL的Seconds_Behind_Master
2021/05/18 MySQL