Pytorch 实现自定义参数层的例子


Posted in Python onAugust 17, 2019

注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。

官方Linear层:

class Linear(Module):
  def __init__(self, in_features, out_features, bias=True):
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.Tensor(out_features, in_features))
    if bias:
      self.bias = Parameter(torch.Tensor(out_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(1))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    return F.linear(input, self.weight, self.bias)

  def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None
    )

实现view层

class Reshape(nn.Module):
  def __init__(self, *args):
    super(Reshape, self).__init__()
    self.shape = args

  def forward(self, x):
    return x.view((x.size(0),)+self.shape)

实现LinearWise层

class LinearWise(nn.Module):
  def __init__(self, in_features, bias=True):
    super(LinearWise, self).__init__()
    self.in_features = in_features

    self.weight = nn.Parameter(torch.Tensor(self.in_features))
    if bias:
      self.bias = nn.Parameter(torch.Tensor(self.in_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(0))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    x = input * self.weight
    if self.bias is not None:
      x = x + self.bias
    return x

以上这篇Pytorch 实现自定义参数层的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的hypot()方法使用简介
May 18 Python
Python实现的基数排序算法原理与用法实例分析
Nov 23 Python
python爬虫爬取淘宝商品信息
Feb 23 Python
python 日志增量抓取实现方法
Apr 28 Python
详解Python做一个名片管理系统
Mar 14 Python
python logging模块书写日志以及日志分割详解
Jul 22 Python
python+OpenCV实现图像拼接
Mar 05 Python
python中执行smtplib失败的处理方法
Jul 01 Python
Python pysnmp使用方法及代码实例
Aug 24 Python
python爬虫智能翻页批量下载文件的实例详解
Feb 02 Python
Python文件的操作示例的详细讲解
Apr 08 Python
python开发实时可视化仪表盘的示例
May 07 Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
You might like
PHP基于mcript扩展实现对称加密功能示例
2019/02/21 PHP
Yii2.0框架behaviors方法使用实例分析
2019/09/30 PHP
Javascript调试工具(下载)
2007/01/09 Javascript
jQuery 性能优化指南 (1)
2009/05/21 Javascript
轻松创建nodejs服务器(1):一个简单nodejs服务器例子
2014/12/18 NodeJs
情人节单身的我是如何在敲完代码之后收到12束玫瑰的(javascript)
2015/08/21 Javascript
nodejs实现bigpipe异步加载页面方案
2016/01/26 NodeJs
在html中引入外部js文件,并调用带参函数的方法
2016/10/31 Javascript
微信小程序 实现拖拽事件监听实例详解
2016/11/16 Javascript
JS触摸事件、手势事件详解
2017/05/04 Javascript
Redux实现组合计数器的示例代码
2018/07/04 Javascript
微信小程序实现元素渐入渐出动画效果封装方法
2019/05/18 Javascript
[07:12]2014DOTA2西雅图国际邀请赛 黑马Liquid专题采访
2014/07/12 DOTA
Python发送http请求解析返回json的实例
2018/03/26 Python
Python实现爬取马云的微博功能示例
2019/02/16 Python
Python3 文章标题关键字提取的例子
2019/08/26 Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
2020/06/28 Python
Python 创建守护进程的示例
2020/09/29 Python
python用tkinter实现一个gui的翻译工具
2020/10/26 Python
CSS3实现文字波浪线效果示例代码
2016/11/20 HTML / CSS
大学自荐信
2013/12/12 职场文书
给女儿的表扬信
2014/01/18 职场文书
建筑系毕业生自我鉴定
2014/01/24 职场文书
高中生期末评语大全
2014/01/28 职场文书
开学典礼感言
2014/02/16 职场文书
企业管理毕业生求职信范文
2014/03/07 职场文书
小学生作文评语大全
2014/04/21 职场文书
社会调查研究计划书
2014/05/01 职场文书
化学专业大学生职业生涯规划范文
2014/09/13 职场文书
物流专业专科生职业生涯规划书
2014/09/14 职场文书
党的群众路线教育实践活动学习笔记范文
2014/11/06 职场文书
三年级上册科学教学计划
2015/01/21 职场文书
医生个人年终总结
2015/02/28 职场文书
拾金不昧表扬稿大全
2015/05/05 职场文书
《画家和牧童》教学反思
2016/02/17 职场文书
MySQL数据库查询之多表查询总结
2022/08/05 MySQL