pytorch 共享参数的示例


Posted in Python onAugust 17, 2019

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))
 
  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))
 
  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())
 
    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))
 
  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
django接入新浪微博OAuth的方法
Jun 29 Python
Python实现的选择排序算法示例
Nov 29 Python
python/sympy求解矩阵方程的方法
Nov 08 Python
Python实现程序判断季节的代码示例
Jan 28 Python
python如何实现从视频中提取每秒图片
Oct 22 Python
Python列表的切片实例讲解
Aug 20 Python
Python 实现判断图片格式并转换,将转换的图像存到生成的文件夹中
Jan 13 Python
tensorflow 报错unitialized value的解决方法
Feb 06 Python
python烟花效果的代码实例
Feb 25 Python
Python GUI编程学习笔记之tkinter中messagebox、filedialog控件用法详解
Mar 30 Python
python线程池 ThreadPoolExecutor 的用法示例
Oct 10 Python
python tkinter实现连连看游戏
Nov 16 Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
You might like
php smarty 二级分类代码和模版循环例子
2011/06/16 PHP
PHP查询数据库中满足条件的记录条数(两种实现方法)
2013/01/29 PHP
PHP类的反射用法实例
2014/11/03 PHP
php验证session无效的解决方法
2014/11/04 PHP
php使用cookie实现记住用户名和密码实现代码
2015/04/27 PHP
PHP 中魔术常量的实例详解
2017/10/26 PHP
JavaScript中的16进制字符(改进)
2011/11/21 Javascript
js函数返回多个返回值的示例代码
2013/11/05 Javascript
JS实现简单的键盘打字的效果
2015/04/24 Javascript
AngularJS基础学习笔记之指令
2015/05/10 Javascript
JQuery给select添加/删除节点的实现代码
2016/04/26 Javascript
JavaScript实现简单的日历效果
2016/09/25 Javascript
前端 Vue.js 和 MVVM 详细介绍
2016/12/29 Javascript
angularJS模态框$modal实例代码
2017/05/27 Javascript
JS通过识别id、value值对checkbox设置选中状态
2020/02/19 Javascript
vue实现计算器功能
2020/02/22 Javascript
JS apply用法总结和使用场景实例分析
2020/03/14 Javascript
JavaScript判断数据类型有几种方法及区别介绍
2020/09/02 Javascript
js屏蔽F12审查元素,禁止修改页面代码等实现代码
2020/10/02 Javascript
[51:15]2014 DOTA2国际邀请赛中国区预选赛 Orenda VS LGD-GAMING
2014/05/22 DOTA
python使用7z解压apk包的方法
2015/04/18 Python
python实现在windows下操作word的方法
2015/04/28 Python
从Python的源码来解析Python下的freeblock
2015/05/11 Python
Python制作简单的网页爬虫
2015/11/22 Python
pandas带有重复索引操作方法
2018/06/08 Python
Python 装饰器原理、定义与用法详解
2019/12/07 Python
Python hashlib模块实例使用详解
2019/12/24 Python
pytorch方法测试——激活函数(ReLU)详解
2020/01/15 Python
在tensorflow中实现去除不足一个batch的数据
2020/01/20 Python
python openssl模块安装及用法
2020/12/06 Python
主办会计岗位职责
2014/03/13 职场文书
教师节联欢会主持词
2015/07/04 职场文书
小学语文课《掌声》教学反思
2016/03/03 职场文书
辞职信怎么写?你都知道吗?
2019/06/24 职场文书
解决Golang中goroutine执行速度的问题
2021/05/02 Golang
Go timer如何调度
2021/06/09 Golang