pytorch 共享参数的示例


Posted in Python onAugust 17, 2019

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))
 
  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))
 
  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())
 
    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))
 
  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Socket编程详细介绍
Mar 23 Python
python使用super()出现错误解决办法
Aug 14 Python
Python实现PS图像调整之对比度调整功能示例
Jan 26 Python
Ubuntu下使用python读取doc和docx文档的内容方法
May 08 Python
python学习笔记--将python源文件打包成exe文件(pyinstaller)
May 26 Python
python斐波那契数列的计算方法
Sep 27 Python
Django框架中间件定义与使用方法案例分析
Nov 28 Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 Python
关于Tensorflow使用CPU报错的解决方式
Feb 05 Python
Python getattr()函数使用方法代码实例
Aug 10 Python
C++和python实现阿姆斯特朗数字查找实例代码
Dec 07 Python
Appium中scroll和drag_and_drop根据元素位置滑动
Feb 15 Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
You might like
PHP通过COM使用ADODB的简单例子
2006/12/31 PHP
在html文件中也可以执行php语句的方法
2015/04/09 PHP
PHP超牛逼无限极分类生成树方法
2015/05/11 PHP
php版微信公众平台入门教程之开发者认证的方法
2016/09/26 PHP
Yii Framework框架使用PHPExcel组件的方法示例
2019/07/24 PHP
php设计模式之享元模式分析【星际争霸游戏案例】
2020/03/23 PHP
js封装的textarea操作方法集合(兼容很好)
2010/11/16 Javascript
扩展Jquery插件处理mouseover时内部有子元素时发生样式闪烁
2011/12/08 Javascript
多选列表框动态添加,移动,删除,全选等操作的简单实例
2014/01/13 Javascript
Jquery插件分享之气泡形提示控件grumble.js
2014/05/20 Javascript
javascript使用appendChild追加节点实例
2015/01/12 Javascript
Three.js的使用及绘制基础3D图形详解
2017/04/27 Javascript
详解Angular2响应式表单
2017/06/14 Javascript
jquery图片放大镜效果
2017/06/23 jQuery
使用Ajax和Jquery配合数据库实现下拉框的二级联动的示例
2018/01/25 jQuery
JavaScript笛卡尔积超简单实现算法示例
2018/07/30 Javascript
100行代码实现一个vue分页组功能
2018/11/06 Javascript
服务端预渲染之Nuxt(使用篇)
2019/04/08 Javascript
vue项目打包后上传至GitHub并实现github-pages的预览
2019/05/06 Javascript
[03:28]2014DOTA2国际邀请赛 走近EG战队天才中单Arteezy
2014/07/12 DOTA
[43:57]LGD vs Mineski 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
Python实现将一个正整数分解质因数的方法分析
2017/12/14 Python
Python Selenium 设置元素等待的三种方式
2020/03/18 Python
Python使用lambda抛出异常实现方法解析
2020/08/20 Python
利用python实现汉诺塔游戏
2021/03/01 Python
英国第一豪华护肤品牌:Elemis
2017/10/12 全球购物
销售行政专员职责
2014/01/03 职场文书
应届毕业生简历自我评价
2014/01/31 职场文书
致接力运动员广播稿
2014/02/17 职场文书
一年级班主任工作总结2014
2014/11/08 职场文书
专家推荐信怎么写
2015/03/25 职场文书
2016三八妇女节慰问信
2015/11/30 职场文书
PyTorch 如何设置随机数种子使结果可复现
2021/05/12 Python
浅谈JS的原型和原型链
2021/06/04 Javascript
如何设置多台电脑共享打印机?多台电脑共享打印机的方法
2022/04/08 数码科技
golang连接MySQl使用sqlx库
2022/04/14 Golang