浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python base64编码解码实例
Jun 21 Python
Python的迭代器和生成器
Jul 29 Python
int在python中的含义以及用法
Jun 27 Python
Python实现的对一个数进行因式分解操作示例
Jun 27 Python
python opencv 简单阈值算法的实现
Aug 04 Python
为什么说Python可以实现所有的算法
Oct 04 Python
python3常用的数据清洗方法(小结)
Oct 31 Python
Anaconda+vscode+pytorch环境搭建过程详解
May 25 Python
使用opencv识别图像红色区域,并输出红色区域中心点坐标
Jun 02 Python
Python selenium爬取微信公众号文章代码详解
Aug 12 Python
Django解决frame拒绝问题的方法
Dec 18 Python
使用Python爬取小姐姐图片(beautifulsoup法)
Feb 11 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
php 字符串函数收集
2010/03/29 PHP
完美解决令人抓狂的zend studio 7代码提示(content Assist)速度慢的问题
2013/06/20 PHP
List Information About the Binary Files Used by an Application
2007/06/11 Javascript
学习ExtJS(一) 之基础前提
2009/10/07 Javascript
纯JS实现五子棋游戏兼容各浏览器(附源码)
2013/04/24 Javascript
谈谈我对JavaScript原型和闭包系列理解(随手笔记9)
2015/12/24 Javascript
倾力总结40条常见的移动端Web页面问题解决方案
2016/05/24 Javascript
Angular 中 select指令用法详解
2016/09/29 Javascript
js实现前端分页页码管理
2017/01/06 Javascript
JS基于正则实现数字千分位用逗号分隔的方法
2017/06/16 Javascript
vscode下的vue文件格式化问题
2018/11/28 Javascript
mocha的时序规则讲解
2019/02/16 Javascript
详解微信小程序胶囊按钮返回|首页自定义导航栏功能
2019/06/14 Javascript
vue如何实现自定义底部菜单栏
2019/07/01 Javascript
[04:11]DOTA2上海特级锦标赛主赛事首日TOP10
2016/03/03 DOTA
python实现哈希表
2014/02/07 Python
Python heapq使用详解及实例代码
2017/01/25 Python
Python基于numpy灵活定义神经网络结构的方法
2017/08/19 Python
Python简单实现查找一个字符串中最长不重复子串的方法
2018/03/26 Python
Django中使用Celery的教程详解
2018/08/24 Python
python使用selenium实现批量文件下载
2019/03/11 Python
Python代码实现删除一个list里面重复元素的方法
2019/04/02 Python
Python如何绘制日历图和热力图
2020/08/07 Python
举例讲解Python装饰器
2020/12/24 Python
详解CSS透明opacity和IE各版本透明度滤镜filter的最准确用法
2016/12/20 HTML / CSS
泰国演唱会订票网站:StubHub泰国
2018/02/26 全球购物
最好的意大利皮夹克:D’Arienzo
2018/12/04 全球购物
估算杭州有多少软件工程师
2015/08/11 面试题
财务会计实习报告体会
2013/12/20 职场文书
自我鉴定 电子商务专业
2014/01/30 职场文书
《和我们一样享受春天》教学反思
2014/02/07 职场文书
《姥姥的剪纸》教学反思
2014/02/25 职场文书
银行委托书范本
2014/09/28 职场文书
库房管理员岗位职责
2015/02/12 职场文书
国王的演讲观后感
2015/06/03 职场文书
浅谈Java实现分布式事务的三种方案
2021/06/11 Java/Android