pytorch 共享参数的示例


Posted in Python onAugust 17, 2019

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))
 
  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))
 
  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())
 
    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))
 
  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的高级Git库 Gittle
Sep 22 Python
python文件操作整理汇总
Oct 21 Python
python如何在终端里面显示一张图片
Aug 17 Python
python+pyqt实现右下角弹出框
Oct 26 Python
python使用socket创建tcp服务器和客户端
Apr 12 Python
Python3正则匹配re.split,re.finditer及re.findall函数用法详解
Jun 11 Python
深入理解Django-Signals信号量
Feb 19 Python
Python 学习教程之networkx
Apr 15 Python
python matplotlib拟合直线的实现
Nov 19 Python
执行Python程序时模块报错问题
Mar 26 Python
Python通过kerberos安全认证操作kafka方式
Jun 06 Python
Matlab如何实现矩阵复制扩充
Jun 02 Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
You might like
PHP 数据库树的遍历方法
2009/02/06 PHP
php处理json时中文问题的解决方法
2011/04/12 PHP
php调用方法mssql_fetch_row、mssql_fetch_array、mssql_fetch_assoc和mssql_fetch_objcect读取数据的区别
2012/08/08 PHP
php管理nginx虚拟主机shell脚本实例
2014/11/19 PHP
[原创]php使用curl判断网页404(不存在)的方法
2016/06/23 PHP
php处理抢购类功能的高并发请求
2018/02/08 PHP
JavaScript 异步调用框架 (Part 3 - 代码实现)
2009/08/04 Javascript
在网页中使用document.write时遭遇的奇怪问题
2010/08/24 Javascript
原生javascript兼容性测试实例
2013/07/01 Javascript
window.navigate 与 window.location.href 的使用区别介绍
2013/09/21 Javascript
jQuery 仿百度输入标签插件附效果图
2014/07/04 Javascript
Jquery仿IGoogle实现可拖动窗口示例代码
2014/08/22 Javascript
js省市联动效果完整实例代码
2015/12/09 Javascript
基于JavaScript实现移动端无限加载分页
2017/03/27 Javascript
基于Bootstrap分页的实例讲解(必看篇)
2017/07/04 Javascript
Vue插槽_特殊特性slot,slot-scope与指令v-slot说明
2020/09/04 Javascript
vue制作toast组件npm包示例代码
2020/10/29 Javascript
[00:53]2015国际邀请赛 中国区预选赛一触即发
2015/05/14 DOTA
python的id()函数解密过程
2012/12/25 Python
python函数返回多个值的示例方法
2013/12/04 Python
python基于urllib实现按照百度音乐分类下载mp3的方法
2015/05/25 Python
python递归打印某个目录的内容(实例讲解)
2017/08/30 Python
Python设计模式之MVC模式简单示例
2018/01/10 Python
Python爬虫实现抓取京东店铺信息及下载图片功能示例
2018/08/07 Python
python调用百度语音识别实现大音频文件语音识别功能
2018/08/30 Python
Python 自动登录淘宝并保存登录信息的方法
2019/09/04 Python
python numpy之np.random的随机数函数使用介绍
2019/10/06 Python
python输入一个水仙花数(三位数) 输出百位十位个位实例
2020/05/03 Python
HTML5中如何显示视频呢 HTML5视频播放demo
2013/06/08 HTML / CSS
澳大利亚在线百货商店:Real Smart
2017/08/13 全球购物
最畅销的视频游戏享受高达90%的折扣:CDKeys
2020/02/10 全球购物
介绍下Java中==和equals的区别
2013/09/01 面试题
韩语专业本科生求职信
2013/10/01 职场文书
社团活动总结
2014/04/28 职场文书
教师工作表现自我评价
2015/03/05 职场文书
分享7个 Python 实战项目练习
2022/03/03 Python