Pytorch 实现自定义参数层的例子


Posted in Python onAugust 17, 2019

注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。

官方Linear层:

class Linear(Module):
  def __init__(self, in_features, out_features, bias=True):
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.Tensor(out_features, in_features))
    if bias:
      self.bias = Parameter(torch.Tensor(out_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(1))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    return F.linear(input, self.weight, self.bias)

  def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None
    )

实现view层

class Reshape(nn.Module):
  def __init__(self, *args):
    super(Reshape, self).__init__()
    self.shape = args

  def forward(self, x):
    return x.view((x.size(0),)+self.shape)

实现LinearWise层

class LinearWise(nn.Module):
  def __init__(self, in_features, bias=True):
    super(LinearWise, self).__init__()
    self.in_features = in_features

    self.weight = nn.Parameter(torch.Tensor(self.in_features))
    if bias:
      self.bias = nn.Parameter(torch.Tensor(self.in_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(0))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    x = input * self.weight
    if self.bias is not None:
      x = x + self.bias
    return x

以上这篇Pytorch 实现自定义参数层的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Turtle画正螺旋线的方法
Sep 22 Python
Python实现正整数分解质因数操作示例
Aug 01 Python
python 生成图形验证码的方法示例
Nov 11 Python
在Python文件中指定Python解释器的方法
Feb 18 Python
python mac下安装虚拟环境的图文教程
Apr 12 Python
Django结合ajax进行页面实时更新的例子
Aug 12 Python
python单例模式原理与创建方法实例分析
Oct 26 Python
Django框架表单操作实例分析
Nov 04 Python
python实现在多维数组中挑选符合条件的全部元素
Nov 26 Python
Python打包工具PyInstaller的安装与pycharm配置支持PyInstaller详细方法
Feb 27 Python
Python HTMLTestRunner可视化报告实现过程解析
Apr 10 Python
Python的Tqdm模块实现进度条配置
Feb 24 Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
You might like
用户的详细注册和判断
2006/10/09 PHP
一个PHP的QRcode类与大家分享
2011/11/13 PHP
基于php split()函数的用法详解
2013/06/05 PHP
关于PHP自动判断字符集并转码的详解
2013/06/26 PHP
php采集内容中带有图片地址的远程图片并保存的方法
2015/01/03 PHP
PHP给源代码加密的几种方法汇总(推荐)
2018/02/06 PHP
将HTML自动转为JS代码
2006/06/26 Javascript
javascript 内存回收机制理解
2011/01/17 Javascript
对象无length属性时IE6/IE7中无法将其转换成伪数组(ArrayLike)
2011/07/31 Javascript
原生JavaScript实现合并多个数组示例
2014/09/21 Javascript
AngularJs expression详解及简单示例
2016/09/01 Javascript
AngularJS实现动态编译添加到dom中的方法
2016/11/04 Javascript
JavaScript用构造函数如何获取变量的类型名
2016/12/23 Javascript
canvas实现刮刮卡效果
2017/03/14 Javascript
详解html-webpack-plugin用法全解
2018/01/22 Javascript
vue进行图片的预加载watch用法实例讲解
2018/02/07 Javascript
NodeJS简单实现WebSocket功能示例
2018/02/10 NodeJs
浅谈webpack打包之后的文件过大的解决方法
2018/03/07 Javascript
JavaScript的数据类型转换原则(干货)
2018/03/15 Javascript
微信小程序在线客服自动回复功能(基于node)
2019/07/03 Javascript
[01:32]寻找你心中的那团火 DOTA2 TI9火焰传递活动今日开启
2019/05/16 DOTA
python虚拟环境virtualenv的使用教程
2017/10/20 Python
基于Python检测动态物体颜色过程解析
2019/12/04 Python
python简单的三元一次方程求解实例
2020/04/02 Python
Python3读取和写入excel表格数据的示例代码
2020/06/09 Python
Python如何设置指定窗口为前台活动窗口
2020/08/12 Python
比利时香水网上商店:NOTINO
2018/03/28 全球购物
网络技术支持面试题
2013/04/22 面试题
编辑个人求职信范文
2013/09/21 职场文书
你懂得怎么写自荐信吗?
2013/12/27 职场文书
个人简历自我评价
2014/01/06 职场文书
社区科普工作方案
2014/06/03 职场文书
家庭经济困难证明
2015/06/23 职场文书
学子宴致辞大全
2015/07/27 职场文书
大学生村官工作心得体会
2016/01/23 职场文书
Vue CLI中模式与环境变量的深入详解
2021/05/30 Vue.js