浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python random模块常用方法
Nov 03 Python
Python实现数据库并行读取和写入实例
Jun 09 Python
利用python编写一个图片主色转换的脚本
Dec 07 Python
Python爬虫的两套解析方法和四种爬虫实现过程
Jul 20 Python
在Pycharm中对代码进行注释和缩进的方法详解
Jan 20 Python
Pandas之DataFrame对象的列和索引之间的转化
Jun 25 Python
Python 使用 docopt 解析json参数文件过程讲解
Aug 13 Python
详解Python的三种拷贝方式
Feb 11 Python
Python浮点型(float)运算结果不正确的解决方案
Sep 22 Python
浅谈怎么给Python添加类型标注
Jun 08 Python
Python使用Beautiful Soup(BS4)库解析HTML和XML
Jun 05 Python
python pandas 解析(读取、写入)CSV 文件的操作方法
Dec 24 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
php学习之 循环结构实现代码
2011/06/09 PHP
PHP判断文件是否存在、是否可读、目录是否存在的代码
2012/10/03 PHP
php通过array_merge()函数合并关联和非关联数组的方法
2015/03/18 PHP
浅谈PHP中JSON数据操作
2015/07/01 PHP
解决php用mysql方式连接数据库出现Deprecated报错问题
2019/12/25 PHP
CSS心形加载的动画源码的实现
2021/03/09 HTML / CSS
JavaScript面向对象编程
2008/03/02 Javascript
javascript面向对象入门基础详细介绍
2012/09/05 Javascript
jquery实现table鼠标经过变色代码
2013/09/25 Javascript
jquery网页回到顶部效果(图标渐隐,自写)
2014/06/16 Javascript
JS实现漂亮的淡蓝色滑动门效果代码
2015/09/23 Javascript
JavaScript获取图片像素颜色并转换为box-shadow显示
2016/03/11 Javascript
Vue.js移动端左滑删除组件的实现代码
2017/09/08 Javascript
微信小程序自定义组件实现tabs选项卡功能
2018/07/14 Javascript
nodejs aes 加解密实例
2018/10/10 NodeJs
原生js实现Flappy Bird小游戏
2018/12/24 Javascript
详解vue 图片上传功能
2019/04/30 Javascript
layui的layedit富文本赋值方法
2019/09/18 Javascript
JS正则表达式常见函数与用法小结
2020/04/13 Javascript
使用Python编写简单的端口扫描器的实例分享
2015/12/18 Python
python中函数默认值使用注意点详解
2016/06/01 Python
ubuntu安装mysql pycharm sublime
2018/02/20 Python
python文件操作之批量修改文件后缀名的方法
2018/08/10 Python
Python3列表内置方法大全及示例代码小结
2019/05/10 Python
基于Numpy.convolve使用Python实现滑动平均滤波的思路详解
2019/05/16 Python
python实现WebSocket服务端过程解析
2019/10/18 Python
Python Json数据文件操作原理解析
2020/05/09 Python
中国旅游网站:同程旅游
2016/09/11 全球购物
Baracuta官方网站:Harrington夹克,G9,G4,G10等
2018/03/06 全球购物
大学生个人总结的自我评价
2013/10/05 职场文书
入党自荐书范文
2014/03/09 职场文书
幼儿园六一儿童节主持节目串词
2014/03/21 职场文书
建国大业观后感800字
2015/06/01 职场文书
教你快速开启Apache SkyWalking的自监控
2021/04/25 Servers
SpringBoot快速入门详解
2021/07/21 Java/Android
tomcat正常启动但网页却无法访问的几种解决方法
2022/05/06 Servers