浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用装饰器和线程限制函数执行时间的方法
Apr 18 Python
Python中threading模块join函数用法实例分析
Jun 04 Python
利用Python暴力破解zip文件口令的方法详解
Dec 21 Python
python 接口测试response返回数据对比的方法
Feb 11 Python
python计算两个数的百分比方法
Jun 29 Python
Python补齐字符串长度的实例
Nov 15 Python
python xpath获取页面注释的方法
Jan 14 Python
Python中extend和append的区别讲解
Jan 24 Python
Python装饰器限制函数运行时间超时则退出执行
Apr 09 Python
Python中使用pypdf2合并、分割、加密pdf文件的代码详解
May 21 Python
python递归下载文件夹下所有文件
Aug 31 Python
python使用pip安装SciPy、SymPy、matplotlib教程
Nov 20 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
php正则修正符用法实例详解
2016/12/29 PHP
由php中字符offset特征造成的绕过漏洞详解
2017/07/07 PHP
js原生态函数中使用jQuery中的 $(this)无效的解决方法
2011/05/25 Javascript
jQuery EasyUI API 中文文档 - Pagination分页
2011/09/29 Javascript
jquery提取元素里的纯文本不包含span等里的内容
2013/09/30 Javascript
js实现刷新iframe的方法汇总
2015/04/27 Javascript
JS实现按比例缩放图片的方法(附C#版代码)
2015/12/08 Javascript
Bootstrap每天必学之响应式导航、轮播图
2016/04/25 Javascript
jquery制做精致的倒计时特效
2016/06/13 Javascript
如何使用vuejs实现更好的Form validation?
2017/04/07 Javascript
详解JSON和JSONP劫持以及解决方法
2019/03/08 Javascript
微信小程序合法域名配置方法
2019/05/06 Javascript
Vue scrollBehavior 滚动行为实现后退页面显示在上次浏览的位置
2019/05/27 Javascript
解决ele ui 表格表头太长问题的实现
2019/11/13 Javascript
Python入门篇之正则表达式
2014/10/20 Python
使用Python的Tornado框架实现一个Web端图书展示页面
2016/07/11 Python
详解MySQL数据类型int(M)中M的含义
2016/11/20 Python
Python爬虫实现网页信息抓取功能示例【URL与正则模块】
2017/05/18 Python
python实现简易版计算器
2020/06/22 Python
使用python实现kNN分类算法
2019/10/16 Python
Python安装依赖(包)模块方法详解
2020/02/14 Python
python+selenium 脚本实现每天自动登记的思路详解
2020/03/11 Python
Elasticsearch py客户端库安装及使用方法解析
2020/09/14 Python
Python django框架 web端视频加密的实例详解
2020/11/20 Python
Python项目打包成二进制的方法
2020/12/30 Python
纯CSS3实现绘制各种图形实现代码详细整理
2012/12/26 HTML / CSS
家得宝官网:The Home Depot(全球最大的家居装饰专业零售商)
2018/12/17 全球购物
亚马逊新加坡官方网站:Amazon.sg
2020/03/25 全球购物
如何开发安全的AJAX应用
2014/03/26 面试题
违反学校规定检讨书
2014/01/18 职场文书
晨会主持词
2014/03/17 职场文书
联谊活动总结
2014/08/28 职场文书
2014年幼儿园教师工作总结
2014/11/08 职场文书
推普标语口号大全
2015/12/26 职场文书
MySQL数据库 安全管理
2022/05/06 MySQL
tomcat下部署jenkins的方法
2022/05/06 Servers