pytorch 共享参数的示例


Posted in Python onAugust 17, 2019

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))
 
  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))
 
  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())
 
    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))
 
  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python字符串的encode与decode研究心得乱码问题解决方法
Mar 23 Python
python实现2014火车票查询代码分享
Jan 10 Python
简单介绍Python中的readline()方法的使用
May 24 Python
Python中使用Queue和Condition进行线程同步的方法
Jan 19 Python
Python闭包执行时值的传递方式实例分析
Jun 04 Python
利用pandas进行大文件计数处理的方法
Jul 25 Python
Django数据库连接丢失问题的解决方法
Dec 29 Python
Python之时间和日期使用小结
Feb 14 Python
使用Pandas对数据进行筛选和排序的实现
Jul 29 Python
python实现梯度下降法
Mar 24 Python
numpy库ndarray多维数组的维度变换方法(reshape、resize、swapaxes、flatten)
Apr 28 Python
Python直接赋值及深浅拷贝原理详解
Sep 05 Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
You might like
mysql 性能的检查和优化方法
2009/06/21 PHP
PHP5中使用PDO连接数据库的方法
2010/08/01 PHP
PHP调用VC编写的COM组件实例
2014/03/29 PHP
phpmyadmin配置文件现在需要绝密的短密码(blowfish_secret)的2种解决方法
2014/05/07 PHP
Joomla实现组件中弹出一个模式(modal)窗口的方法
2016/05/04 PHP
BOOM vs RR BO5 第三场 2.14
2021/03/10 DOTA
动态加载js的几种方法
2006/10/23 Javascript
JavaScript中的History历史对象
2008/01/16 Javascript
javascript 语法基础 想学习js的朋友可以看看
2009/12/16 Javascript
jquery实现的超出屏幕时把固定层变为定位层的代码
2010/02/23 Javascript
js 多种变量定义(对象直接量,数组直接量和函数直接量)
2010/05/24 Javascript
Jquery颜色选择器ColorPicker实现代码
2012/11/14 Javascript
AngularJs根据访问的页面动态加载Controller的解决方案
2015/02/04 Javascript
JavaScript操作XML文件之XML读取方法
2015/06/09 Javascript
Angular.js实现注册系统的实例详解
2016/12/18 Javascript
JavaScript如何获取到导航条中HTTP信息
2017/10/10 Javascript
JS+WCF实现进度条实时监测数据加载量的方法详解
2017/12/19 Javascript
微信小程序block的使用教程
2018/04/01 Javascript
基于JavaScript或jQuery实现网站夜间/高亮模式
2020/05/30 jQuery
[03:03]2014DOTA2国际邀请赛 EG战队专访
2014/07/12 DOTA
Python实现读取txt文件并画三维图简单代码示例
2017/12/09 Python
python实现几种归一化方法(Normalization Method)
2019/07/31 Python
Python分割训练集和测试集的方法示例
2019/09/19 Python
python使用 __init__初始化操作简单示例
2019/09/26 Python
python的pyecharts绘制各种图表详细(附代码)
2019/11/11 Python
Python图像处理库PIL的ImageFont模块使用介绍
2020/02/26 Python
Python爬虫进阶之爬取某视频并下载的实现
2020/12/08 Python
python上下文管理器异常问题解决方法
2021/02/07 Python
野兽派官方旗舰店:THE BEAST 野兽派
2016/08/05 全球购物
澳大利亚冒险体验:Adrenaline(跳伞、V8赛车、热气球等)
2017/09/18 全球购物
计算机专业自荐信
2013/10/14 职场文书
教师个人的自我评价分享
2014/01/02 职场文书
小学清明节活动方案
2014/03/08 职场文书
电钳工人个人求职信
2014/05/10 职场文书
分析并发编程之LongAdder原理
2021/06/29 Java/Android
django中websocket的具体使用
2022/01/22 Python