浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现在线程里运行scrapy的方法
Apr 07 Python
Python简单检测文本类型的2种方法【基于文件头及cchardet库】
Sep 18 Python
一张图带我们入门Python基础教程
Feb 05 Python
Python实现树莓派WiFi断线自动重连的实例代码
Mar 16 Python
Python协程的用法和例子详解
Sep 09 Python
selenium+python自动化测试之页面元素定位
Jan 23 Python
pyqt5 禁止窗口最大化和禁止窗口拉伸的方法
Jun 18 Python
pandas DataFrame 交集并集补集的实现
Jun 24 Python
pandas实现to_sql将DataFrame保存到数据库中
Jul 03 Python
python中数据库like模糊查询方式
Mar 02 Python
浅谈matplotlib 绘制梯度下降求解过程
Jul 12 Python
Python如何把不同类型数据的json序列化
Apr 30 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
Codeigniter中mkdir创建目录遇到权限问题和解决方法
2014/07/25 PHP
smarty简单分页的实现方法
2014/10/27 PHP
PHP调用QQ互联接口实现QQ登录网站功能示例
2019/10/24 PHP
ie和firefox不兼容的解决方法集合
2009/04/28 Javascript
返回对象在当前级别中是第几个元素的实现代码
2011/01/20 Javascript
3款实用的在线JS代码工具(国外)
2012/03/15 Javascript
javascript实现可改变滚动方向的无缝滚动实例
2013/06/17 Javascript
jQuery中阻止冒泡事件的方法介绍
2014/04/12 Javascript
js操作模态窗口及父子窗口间相互传值示例
2014/06/09 Javascript
jQuery中用dom操作替代正则表达式
2014/12/29 Javascript
jQuery实现可用于博客的动态滑动菜单
2015/03/09 Javascript
javascript实现Table间隔色以及选择高亮(和动态切换数据)的方法
2015/05/14 Javascript
基于JavaScript实现图片点击弹出窗口而不是保存
2016/02/06 Javascript
Angular 根据 service 的状态更新 directive
2016/04/03 Javascript
简单理解vue中el、template、replace元素
2016/10/27 Javascript
JavaScript递归操作实例浅析
2016/10/31 Javascript
快速实现JS图片懒加载(可视区域加载)示例代码
2017/01/04 Javascript
Vue项目中跨域问题解决方案
2018/06/05 Javascript
优雅的在React项目中使用Redux的方法
2018/11/10 Javascript
[07:57]2018DOTA2国际邀请赛寻真——PSG.LGD凤凰浴火
2018/08/12 DOTA
Python3实现生成随机密码的方法
2014/08/23 Python
Python设计模式之备忘录模式原理与用法详解
2019/01/15 Python
python调用c++ ctype list传数组或者返回数组的方法
2019/02/13 Python
python使用wxpy实现微信消息防撤回脚本
2019/04/29 Python
Python GUI编程学习笔记之tkinter控件的介绍及基本使用方法详解
2020/03/30 Python
PyTorch的torch.cat用法
2020/06/28 Python
高三政治教学反思
2014/02/06 职场文书
青奥会口号
2014/06/12 职场文书
2014年办公室文员工作总结
2014/11/12 职场文书
暑期辅导班宣传单
2015/07/14 职场文书
2016银行招聘自荐信
2016/01/28 职场文书
Java面试题冲刺第十六天--消息队列
2021/08/07 面试题
Oracle 触发器trigger使用案例
2022/02/24 Oracle
PostgreSQL事务回卷实战案例详析
2022/03/25 PostgreSQL
python中 .npy文件的读写操作实例
2022/04/14 Python
MySQL数据库实验之 触发器和存储过程
2022/06/21 MySQL