pytorch 共享参数的示例


Posted in Python onAugust 17, 2019

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))
 
  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))
 
  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())
 
    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))
 
  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
教你用python3根据关键词爬取百度百科的内容
Aug 18 Python
Python利用Beautiful Soup模块搜索内容详解
Mar 29 Python
解决Python 爬虫URL中存在中文或特殊符号无法请求的问题
May 11 Python
Python自动发送邮件的方法实例总结
Dec 08 Python
GitHub 热门:Python 算法大全,Star 超过 2 万
Apr 29 Python
Python发送邮件的实例代码讲解
Oct 16 Python
使用pandas实现连续数据的离散化处理方式(分箱操作)
Nov 22 Python
解析PyCharm Python运行权限问题
Jan 08 Python
Tensorflow实现部分参数梯度更新操作
Jan 23 Python
Python日志logging模块功能与用法详解
Apr 09 Python
Django ORM判断查询结果是否为空,判断django中的orm为空实例
Jul 09 Python
Python实现DBSCAN聚类算法并样例测试
Jun 22 Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
You might like
PHP在字符断点处截断文字的实现代码
2011/04/21 PHP
PHP基础教程(php入门基础教程)一些code代码
2013/01/06 PHP
PHP读取CSV大文件导入数据库的实例
2017/07/24 PHP
载入进度条 效果
2006/07/08 Javascript
csdn 博客的css样式 v3
2009/02/24 Javascript
JavaScript与Image加载事件(onload)、加载状态(complete)
2011/02/14 Javascript
JavaScript面向对象设计二 构造函数模式
2011/12/20 Javascript
JavaScript数组前面插入元素的方法
2015/04/06 Javascript
js+css实现有立体感的按钮式文字竖排菜单效果
2015/09/01 Javascript
jQuery Form 表单提交插件之formSerialize,fieldSerialize,fieldValue,resetForm,clearForm,clearFields的应用
2016/01/23 Javascript
jQuery mobile的header和footer在点击屏幕的时候消失的解决办法
2016/07/01 Javascript
详解Javascript ES6中的箭头函数(Arrow Functions)
2016/08/24 Javascript
再谈javascript注入 黑客必备!
2016/09/14 Javascript
JavaScript实现的冒泡排序法及统计相邻数交换次数示例
2017/04/26 Javascript
JavaScript模拟实现自由落体效果
2018/08/28 Javascript
element-ui 的el-button组件中添加自定义颜色和图标的实现方法
2018/10/26 Javascript
layer更改皮肤的实现方法
2019/09/11 Javascript
JavaScript获取页面元素的常用方法详解
2019/09/28 Javascript
使用Vue调取接口,并渲染数据的示例代码
2019/10/28 Javascript
JavaScript实现指定数量的并发限制的示例代码
2020/03/10 Javascript
Node在Controller层进行数据校验的过程详解
2020/08/28 Javascript
[02:57]DOTA2亚洲邀请赛 SECRET战队出场宣传片
2015/02/07 DOTA
python模块restful使用方法实例
2013/12/10 Python
Python错误提示:[Errno 24] Too many open files的分析与解决
2017/02/16 Python
用python实现将数组元素按从小到大的顺序排列方法
2018/07/02 Python
基于Python对数据shape的常见操作详解
2018/12/25 Python
python3 深浅copy对比详解
2019/08/12 Python
python实现马丁策略回测3000只股票的实例代码
2021/01/22 Python
HTML5 progress和meter控件_动力节点Java学院整理
2017/07/06 HTML / CSS
canvas实现圆形进度条动画的示例代码
2017/12/26 HTML / CSS
一套C++笔试题面试题
2012/06/06 面试题
应用心理学个人的求职信
2013/12/08 职场文书
教师实习自我鉴定
2013/12/18 职场文书
中学生的1000字检讨书
2014/10/11 职场文书
2014年党的群众路线活动个人整改措施
2014/10/28 职场文书
怎样写家长意见
2015/06/04 职场文书