浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将多个文本文件合并为一个文本的代码(便于搜索)
Mar 13 Python
在Python下尝试多线程编程
Apr 28 Python
Python判断是否json是否包含一个key的方法
Dec 31 Python
PyQt5实现QLineEdit添加clicked信号的方法
Jun 25 Python
Python 计算任意两向量之间的夹角方法
Jul 05 Python
PyTorch之图像和Tensor填充的实例
Aug 18 Python
Django项目使用ckeditor详解(不使用admin)
Dec 17 Python
python GUI编程(Tkinter) 创建子窗口及在窗口上用图片绘图实例
Mar 04 Python
Python Dataframe常见索引方式详解
May 27 Python
基于Python实现全自动下载抖音视频
Nov 06 Python
python文件名批量重命名脚本实例代码
Apr 22 Python
PYTHON基于Pyecharts绘制常见的直角坐标系图表
Apr 28 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
php实现ping
2006/10/09 PHP
PHP4与PHP5的时间格式问题
2008/02/17 PHP
php 动态多文件上传
2009/01/18 PHP
基于thinkPHP实现的微信自定义分享功能示例
2016/09/23 PHP
PHP实现的只保留字符串首尾字符功能示例【隐藏部分字符串】
2019/03/11 PHP
javascript Base类 包含基本的方法
2009/07/22 Javascript
JavaScript 基于原型的对象(创建、调用)
2009/10/16 Javascript
js防止页面被iframe调用的方法
2014/10/30 Javascript
整理Javascript基础入门学习笔记
2015/11/29 Javascript
JS中如何比较两个Json对象是否相等实例代码
2016/07/13 Javascript
vuejs 单文件组件.vue 文件的使用
2017/07/28 Javascript
vue移动UI框架滑动加载数据的方法
2018/03/12 Javascript
JS中的防抖与节流及作用详解
2019/04/01 Javascript
用node.js写一个jenkins发版脚本
2019/05/21 Javascript
js实现提交前对列表数据的增删改查
2020/01/16 Javascript
微信小程序分享小程序码的生成(带参数)以及参数的获取
2020/03/25 Javascript
vue addRoutes路由动态加载操作
2020/08/04 Javascript
[05:34]2014DOTA2国际邀请赛中国区预选赛精彩TOPPLAY第二弹
2014/06/25 DOTA
Python中的异常处理简明介绍
2015/04/13 Python
Java分治归并排序算法实例详解
2017/12/12 Python
python 识别图片中的文字信息方法
2018/05/10 Python
pygame实现雷电游戏雏形开发
2018/11/20 Python
Python实现微信消息防撤回功能的实例代码
2019/04/29 Python
Django框架orM与自定义SQL语句混合事务控制操作
2019/06/27 Python
python创建ArcGIS shape文件的实现
2019/12/06 Python
Tensorflow轻松实现XOR运算的方式
2020/02/03 Python
python tqdm 实现滚动条不上下滚动代码(保持一行内滚动)
2020/02/19 Python
python numpy库np.percentile用法说明
2020/06/08 Python
解决Keras的自定义lambda层去reshape张量时model保存出错问题
2020/07/01 Python
Python Selenium自动化获取页面信息的方法
2020/08/31 Python
逃课检讨书范文
2015/05/06 职场文书
运动会闭幕式主持词
2015/07/01 职场文书
学校教学管理制度
2015/08/06 职场文书
Html5获取用户当前位置的几种方式
2022/01/18 HTML / CSS
python接口测试返回数据为字典取值方式
2022/02/12 Python
无线电知识基础入门篇
2022/02/18 无线电