Pytorch 实现自定义参数层的例子


Posted in Python onAugust 17, 2019

注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。

官方Linear层:

class Linear(Module):
  def __init__(self, in_features, out_features, bias=True):
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.Tensor(out_features, in_features))
    if bias:
      self.bias = Parameter(torch.Tensor(out_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(1))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    return F.linear(input, self.weight, self.bias)

  def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None
    )

实现view层

class Reshape(nn.Module):
  def __init__(self, *args):
    super(Reshape, self).__init__()
    self.shape = args

  def forward(self, x):
    return x.view((x.size(0),)+self.shape)

实现LinearWise层

class LinearWise(nn.Module):
  def __init__(self, in_features, bias=True):
    super(LinearWise, self).__init__()
    self.in_features = in_features

    self.weight = nn.Parameter(torch.Tensor(self.in_features))
    if bias:
      self.bias = nn.Parameter(torch.Tensor(self.in_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(0))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    x = input * self.weight
    if self.bias is not None:
      x = x + self.bias
    return x

以上这篇Pytorch 实现自定义参数层的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python利用不到一百行代码实现一个小siri
Mar 02 Python
浅析使用Python操作文件
Jul 31 Python
python实现机器学习之元线性回归
Sep 06 Python
实例讲解python中的协程
Oct 08 Python
Django objects的查询结果转化为json的三种方式的方法
Nov 07 Python
python学生信息管理系统实现代码
Dec 17 Python
python 遗传算法求函数极值的实现代码
Feb 11 Python
tensorflow之tf.record实现存浮点数数组
Feb 17 Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 Python
Python爬虫设置ip代理过程解析
Jul 20 Python
python和C++共享内存传输图像的示例
Oct 27 Python
关于python中remove的一些坑小结
Jan 04 Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
You might like
PHP脚本的10个技巧(4)
2006/10/09 PHP
Codeigniter上传图片出现“You did not select a file to upload”错误解决办法
2014/06/12 PHP
PHP基于接口技术实现简单的多态应用完整实例
2017/04/26 PHP
js 与或运算符 || && 妙用
2009/12/09 Javascript
Javascript Objects详解
2014/09/04 Javascript
JavaScript实现页面跳转的几种常用方式
2015/11/28 Javascript
浅谈jQuery中hide和fadeOut的区别 show和fadeIn的区别
2016/08/18 Javascript
jQuery中Find选择器用法示例
2016/09/21 Javascript
Angular4实现动态添加删除表单输入框功能
2017/08/11 Javascript
vue中父子组件注意事项,传值及slot应用技巧
2018/05/09 Javascript
node上的redis调用优化示例详解
2018/10/30 Javascript
通过JS深度判断两个对象字段相同
2019/06/14 Javascript
vue中beforeRouteLeave实现页面回退不刷新的示例代码
2019/11/01 Javascript
JS实现音乐钢琴特效
2020/01/06 Javascript
[00:35]可解锁地面特效
2018/12/20 DOTA
linux环境下安装pyramid和新建项目的步骤
2013/11/27 Python
python的格式化输出(format,%)实例详解
2018/06/01 Python
Python函数的参数常见分类与用法实例详解
2019/03/30 Python
python实现简单日期工具类
2019/04/24 Python
pytorch的梯度计算以及backward方法详解
2020/01/10 Python
pycharm进入时每次都是insert模式的解决方式
2021/02/05 Python
SIXPAD智能健身仪英国官网:革命性的训练装备品牌
2018/09/27 全球购物
综合素质的自我鉴定
2013/10/07 职场文书
大专生简历的自我评价
2013/11/26 职场文书
葡萄牙语专业个人求职信
2013/12/10 职场文书
警察思想汇报
2014/01/04 职场文书
乡村卫生服务一体化管理实施方案
2014/03/30 职场文书
初一学生期末评语
2014/04/24 职场文书
英语故事演讲稿
2014/04/29 职场文书
勤奋学习演讲稿
2014/05/10 职场文书
服务员态度差检讨书
2014/10/28 职场文书
迎新生晚会主持词
2015/06/30 职场文书
房产销售员2015年终工作总结
2015/10/22 职场文书
使用Redis实现秒杀功能的简单方法
2021/05/08 Redis
MySQL创建定时任务
2022/01/22 MySQL
Python使用OpenCV实现虚拟缩放效果
2022/02/28 Python