Pytorch 实现自定义参数层的例子


Posted in Python onAugust 17, 2019

注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。

官方Linear层:

class Linear(Module):
  def __init__(self, in_features, out_features, bias=True):
    super(Linear, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = Parameter(torch.Tensor(out_features, in_features))
    if bias:
      self.bias = Parameter(torch.Tensor(out_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(1))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    return F.linear(input, self.weight, self.bias)

  def extra_repr(self):
    return 'in_features={}, out_features={}, bias={}'.format(
      self.in_features, self.out_features, self.bias is not None
    )

实现view层

class Reshape(nn.Module):
  def __init__(self, *args):
    super(Reshape, self).__init__()
    self.shape = args

  def forward(self, x):
    return x.view((x.size(0),)+self.shape)

实现LinearWise层

class LinearWise(nn.Module):
  def __init__(self, in_features, bias=True):
    super(LinearWise, self).__init__()
    self.in_features = in_features

    self.weight = nn.Parameter(torch.Tensor(self.in_features))
    if bias:
      self.bias = nn.Parameter(torch.Tensor(self.in_features))
    else:
      self.register_parameter('bias', None)
    self.reset_parameters()

  def reset_parameters(self):
    stdv = 1. / math.sqrt(self.weight.size(0))
    self.weight.data.uniform_(-stdv, stdv)
    if self.bias is not None:
      self.bias.data.uniform_(-stdv, stdv)

  def forward(self, input):
    x = input * self.weight
    if self.bias is not None:
      x = x + self.bias
    return x

以上这篇Pytorch 实现自定义参数层的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 类的继承实例详解
Mar 25 Python
python将unicode转为str的方法
Jun 21 Python
Python3.6 Schedule模块定时任务(实例讲解)
Nov 09 Python
python如何实现反向迭代
Mar 20 Python
使用DataFrame删除行和列的实例讲解
Apr 08 Python
Python实现京东秒杀功能代码
May 16 Python
Python中关于浮点数的冷知识
Sep 22 Python
Python散点图与折线图绘制过程解析
Nov 30 Python
python实现两个字典合并,两个list合并
Dec 02 Python
Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】
Dec 19 Python
Python中url标签使用知识点总结
Jan 16 Python
Django自带用户认证系统使用方法解析
Nov 12 Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
画pytorch模型图,以及参数计算的方法
Aug 17 #Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
You might like
改造一台复古桌面收音机
2021/03/02 无线电
PHP VS ASP
2006/10/09 PHP
如何在PHP中进行身份认证
2006/10/09 PHP
PHP 各种排序算法实现代码
2009/08/20 PHP
PHP中的array数组类型分析说明
2010/07/27 PHP
PHP持久连接mysql_pconnect()函数使用介绍
2012/02/05 PHP
destoon二次开发常用数据库操作
2014/06/21 PHP
laravel 修改记住我功能的cookie保存时间的方法
2019/10/14 PHP
PHP设计模式概论【概念、分类、原则等】
2020/05/01 PHP
location.href语句与火狐不兼容的问题
2010/07/04 Javascript
juqery 学习之四 筛选查找
2010/11/30 Javascript
js跨浏览器实现将字符串转化为xml对象的方法
2013/09/25 Javascript
javascript动态判断html元素并执行不同的操作
2014/06/16 Javascript
JavaScript中的类数组对象介绍
2014/12/30 Javascript
Flash图片上传组件 swfupload使用指南
2015/03/14 Javascript
js表格排序实例分析(支持int,float,date,string四种数据类型)
2015/05/06 Javascript
JS+CSS实现大气清新的滑动菜单效果代码
2015/10/22 Javascript
遍历js中对象的属性和值的实例
2016/11/21 Javascript
js canvas实现适用于移动端的百分比仪表盘dashboard
2017/07/18 Javascript
浅谈Javascript常用正则表达式应用
2019/03/08 Javascript
[00:34]TI7不朽珍藏III——地穴编织者不朽展示
2017/07/15 DOTA
Python基于回溯法子集树模板解决取物搭配问题实例
2017/09/02 Python
基于python requests库中的代理实例讲解
2018/05/07 Python
python 从文件夹抽取图片另存的方法
2018/12/04 Python
Pyqt5如何让QMessageBox按钮显示中文示例代码
2019/04/11 Python
Django 构建模板form表单的两种方法
2020/06/14 Python
HTML5是否真的可以取代Flash
2010/02/10 HTML / CSS
优质有机椰子产品:Dr. Goerg
2019/09/24 全球购物
大学毕业生简单自荐信
2013/11/05 职场文书
门卫工作岗位职责
2013/12/17 职场文书
药品促销活动方案
2014/02/14 职场文书
养成教育经验材料
2014/05/26 职场文书
餐饮服务食品安全责任书
2014/07/25 职场文书
数学教师个人总结
2015/02/06 职场文书
四群教育工作总结
2015/08/10 职场文书
Hive导入csv文件示例
2022/06/25 数据库