pytorch 共享参数的示例


Posted in Python onAugust 17, 2019

在很多神经网络中,往往会出现多个层共享一个权重的情况,pytorch可以快速地处理权重共享问题。

例子1:

class ConvNet(nn.Module):
  def __init__(self):
    super(ConvNet, self).__init__()
    self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))
 
  def forward(self, x):
    x = nn.functional.conv2d(x, self.conv_weight, bias=None, stride=1, padding=2, dilation=1, groups=1)
    x = nn.functional.conv2d(x, self.conv_weight.transpose(2, 3).contiguous(), bias=None, stride=1, padding=0, dilation=1,
                 groups=1)
    return x

上边这段程序定义了两个卷积层,这两个卷积层共享一个权重conv_weight,第一个卷积层的权重是conv_weight本身,第二个卷积层是conv_weight的转置。注意在gpu上运行时,transpose()后边必须加上.contiguous()使转置操作连续化,否则会报错。

例子2:

class LinearNet(nn.Module):
  def __init__(self):
    super(LinearNet, self).__init__()
    self.linear_weight = nn.Parameter(torch.randn(3, 3))
 
  def forward(self, x):
    x = nn.functional.linear(x, self.linear_weight)
    x = nn.functional.linear(x, self.linear_weight.t())
 
    return x

这个网络实现了一个双层感知器,权重同样是一个parameter的本身及其转置。

例子3:

class LinearNet2(nn.Module):
  def __init__(self):
    super(LinearNet2, self).__init__()
    self.w = nn.Parameter(torch.FloatTensor([[1.1,0,0], [0,1,0], [0,0,1]]))
 
  def forward(self, x):
    x = x.mm(self.w)
    x = x.mm(self.w.t())
    return x

这个方法直接用mm函数将x与w相乘,与上边的网络效果相同。

以上这篇pytorch 共享参数的示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python通过正则表达式选取callback的方法
Jul 18 Python
python 第三方库的安装及pip的使用详解
May 11 Python
Queue 实现生产者消费者模型(实例讲解)
Nov 13 Python
python中判断文件编码的chardet(实例讲解)
Dec 21 Python
python随机取list中的元素方法
Apr 08 Python
基于DATAFRAME中元素的读取与修改方法
Jun 08 Python
Python实现绘制双柱状图并显示数值功能示例
Jun 23 Python
浅谈Python爬虫基本套路
Mar 25 Python
Gauss-Seidel迭代算法的Python实现详解
Jun 29 Python
python使用socket实现的传输demo示例【基于TCP协议】
Sep 24 Python
python进度条显示-tqmd模块的实现示例
Aug 23 Python
Python特殊属性property原理及使用方法解析
Oct 09 Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
You might like
在PHP中使用XML
2006/10/09 PHP
PHP 杂谈《重构-改善既有代码的设计》之一 重新组织你的函数
2012/04/09 PHP
深入php-fpm的两种进程管理模式详解
2013/06/03 PHP
php检测用户是否用手机(Mobile)访问网站的类
2014/01/09 PHP
ThinkPHP3.1新特性之对页面压缩输出的支持
2014/06/19 PHP
CI框架入门示例之数据库取数据完整实现方法
2014/11/05 PHP
彻底删除thinkphp3.1案例blog标签的方法
2014/12/05 PHP
基于JavaScript 下namespace 功能的简单分析
2013/07/05 Javascript
javascript实现原生ajax的几种方法介绍
2013/09/21 Javascript
简单的ajax连接库分享(不用jquery的ajax)
2014/01/19 Javascript
Jquery倒计时源码分享
2014/05/16 Javascript
js解决select下拉选不中问题
2014/10/14 Javascript
javascript检测flash插件是否被禁用的方法
2016/01/14 Javascript
jQuery实现滚动鼠标放大缩小图片的方法(附demo源码下载)
2016/03/05 Javascript
Javascript 引擎工作机制详解
2016/11/30 Javascript
详解VUE 定义全局变量的几种实现方式
2017/06/01 Javascript
javascript基本常用排序算法解析
2017/09/27 Javascript
代码详解JS操作剪贴板
2018/02/11 Javascript
vue 自定义 select内置组件
2018/04/10 Javascript
JavaScript和TypeScript中的void的具体使用
2019/09/12 Javascript
[52:41]OG vs IG 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/20 DOTA
使用url_helper简化Python中Django框架的url配置教程
2015/05/30 Python
python selenium自动上传有赞单号的操作方法
2018/07/05 Python
用python爬取租房网站信息的代码
2018/12/14 Python
Django模板Templates使用方法详解
2019/07/19 Python
django认证系统 Authentication使用详解
2019/07/22 Python
html5模拟平抛运动(模拟小球平抛运动过程)
2013/07/25 HTML / CSS
HTML5 Notification(桌面提醒)功能使用实例
2014/03/17 HTML / CSS
欧洲最大的品牌水上运动服装和设备在线零售商:Wuituit Outlet
2018/05/05 全球购物
购买原创艺术品:Zatista
2019/11/09 全球购物
VisionPros美国站:加拿大在线隐形眼镜和眼镜零售商
2020/02/11 全球购物
物业门卫岗位职责
2013/12/28 职场文书
晚宴邀请函范文
2014/01/15 职场文书
新闻编辑专业自荐信
2014/07/02 职场文书
《鸡兔同笼》教学反思
2016/02/19 职场文书
Android自定义双向滑动控件
2022/04/19 Java/Android