浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取脚本所在目录的正确方法
Apr 15 Python
python检测是文件还是目录的方法
Jul 03 Python
Python编程中的for循环语句学习教程
Oct 14 Python
python实现随机森林random forest的原理及方法
Dec 21 Python
Python常见数字运算操作实例小结
Mar 22 Python
十分钟搞定pandas(入门教程)
Jun 21 Python
Python中的几种矩阵乘法(小结)
Jul 10 Python
使用Filter过滤python中的日志输出的实现方法
Jul 17 Python
利用pyecharts读取csv并进行数据统计可视化的实现
Apr 17 Python
Python 字典一个键对应多个值的方法
Sep 29 Python
Ubuntu 20.04安装Pycharm2020.2及锁定到任务栏的问题(小白级操作)
Oct 29 Python
基于flask实现五子棋小游戏
May 25 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
网络资源
2006/10/09 PHP
基于thinkphp6.0的success、error实现方法
2019/11/05 PHP
jquery键盘事件介绍
2011/01/31 Javascript
jQuery代码优化 事件委托篇
2011/11/01 Javascript
Document:getElementsByName()使用方法及示例
2013/10/28 Javascript
jquery 操作css样式、位置、尺寸方法汇总
2014/11/28 Javascript
jQuery插件kinMaxShow扩展效果用法实例
2015/05/04 Javascript
JS实现的倒计时效果实例(2则实例)
2015/12/23 Javascript
AngularJs Javascript MVC 框架
2016/06/20 Javascript
利用JS提交表单的几种方法和验证(必看篇)
2016/09/17 Javascript
JavaScript基于扩展String实现替换字符串中index处字符的方法
2017/06/13 Javascript
vue打包后显示空白正确处理方法
2017/11/01 Javascript
React事件处理的机制及原理
2018/12/03 Javascript
详解微信小程序获取当前时间及日期的方法
2019/04/28 Javascript
微信小程序实现页面分享onShareAppMessage
2019/08/12 Javascript
layui 根据后台数据动态创建下拉框并同时默认选中的实例
2019/09/02 Javascript
layui 关闭open弹出框 刷新table表格页面的方法
2019/09/16 Javascript
通过高德地图API获得某条道路上的所有坐标用于描绘道路的方法
2020/08/24 Javascript
[54:45]2018DOTA2亚洲邀请赛 4.1 小组赛 A组 Optic vs OG
2018/04/02 DOTA
python实现通过代理服务器访问远程url的方法
2015/04/29 Python
Python中unittest模块做UT(单元测试)使用实例
2015/06/12 Python
Apache如何部署django项目
2017/05/21 Python
Python实现的堆排序算法示例
2018/04/29 Python
基于Python开发chrome插件的方法分析
2018/07/07 Python
深入浅析Python中list的复制及深拷贝与浅拷贝
2018/09/03 Python
Django框架自定义session处理操作示例
2019/05/27 Python
python GUI库图形界面开发之PyQt5线程类QThread详细使用方法
2020/02/26 Python
css3实现顶部社会化分享按钮示例
2014/05/06 HTML / CSS
日本即尚网:JSHOPPERS.com(支持中文)
2019/12/03 全球购物
机械制造与自动化应届生求职信
2013/11/16 职场文书
关于赌博的检讨书
2014/01/08 职场文书
书法比赛获奖感言
2014/02/10 职场文书
2014年社区卫生工作总结
2014/12/18 职场文书
横空出世观后感
2015/06/09 职场文书
2015重阳节敬老活动总结
2015/07/29 职场文书
高三物理教学反思
2016/02/20 职场文书