浅析PyTorch中nn.Module的使用


Posted in Python onAugust 18, 2019

torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作

torch.nn.Module 是所有神经网络单元的基类

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

实例展示

简单搭建:

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self, n_feature, n_hidden, n_output):
    super(Net, self).__init__()
    self.hidden = nn.Linear(n_feature, n_hidden)
    self.out = nn.Linear(n_hidden, n_output)

  def forward(self, x):
    x = F.relu(self.hidden(x))
    x = self.out(x)
    return x

Net 类继承了 torch 的 Module 和 __init__ 功能

hidden 是隐藏层线性输出

out 是输出层线性输出

打印出网络的结构:

>>> net = Net(n_feature=10, n_hidden=30, n_output=15)
>>> print(net)
Net(
 (hidden): Linear(in_features=10, out_features=30, bias=True)
 (out): Linear(in_features=30, out_features=15, bias=True)
)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python安装官方whl包和tar.gz包的方法(推荐)
Jun 04 Python
python编程测试电脑开启最大线程数实例代码
Feb 09 Python
Python实用技巧之利用元组代替字典并为元组元素命名
Jul 11 Python
详解Python进阶之切片的误区与高级用法
Dec 24 Python
Python Django 命名空间模式的实现
Aug 09 Python
安装PyInstaller失败问题解决
Dec 14 Python
解决pycharm最左侧Tool Buttons显示不全的问题
Dec 17 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 Python
python3.7添加dlib模块的方法
Jul 01 Python
Python如何读写二进制数组数据
Aug 01 Python
利用python+request通过接口实现人员通行记录上传功能
Jan 13 Python
Pytest中skip和skipif的具体使用方法
Jun 30 Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
You might like
php使用GeoIP库实例
2014/06/27 PHP
phpmailer发送邮件之后,返回收件人是否阅读了邮件的方法
2014/07/19 PHP
php解析xml方法实例详解
2015/05/12 PHP
浅谈javascript的调试
2015/01/28 Javascript
javascript实现简单的贪吃蛇游戏
2015/03/31 Javascript
jQuery选择器源码解读(三):tokenize方法
2015/03/31 Javascript
jQuery超赞的评分插件(8款)
2015/08/20 Javascript
JS禁用页面上所有控件的实现方法(附demo源码下载)
2015/12/17 Javascript
jQuery实现textarea自动增长宽高的方法
2015/12/18 Javascript
JavaScript实现的MD5算法完整实例
2016/02/02 Javascript
创建一个类Person的简单实例
2016/05/17 Javascript
JS常用字符串方法(推荐)
2021/01/15 Javascript
Angular ng-repeat 对象和数组遍历实例
2016/09/14 Javascript
JS只能输入正整数的简单实例
2016/10/07 Javascript
前端JS面试中常见的算法问题总结
2016/12/23 Javascript
jquery.cookie.js的介绍与使用方法
2017/02/09 Javascript
Angular使用Restful的增删改
2018/12/28 Javascript
使用JS来动态操作css的几种方法
2019/12/18 Javascript
Vue两个版本的区别和使用方法(更深层次了解)
2020/02/16 Javascript
JavaScript常用工具函数大全
2020/05/06 Javascript
React实现阿里云OSS上传文件的示例
2020/08/10 Javascript
Python实现的多项式拟合功能示例【基于matplotlib】
2018/05/15 Python
pycharm 主题theme设置调整仿sublime的方法
2018/05/23 Python
在Python中使用turtle绘制多个同心圆示例
2019/11/23 Python
python使用正则表达式去除中文文本多余空格,保留英文之间空格方法详解
2020/02/11 Python
Python pip install之SSL异常处理操作
2020/09/03 Python
Tahari ASL官方网站:高级设计师女装
2021/03/15 全球购物
this关键字的含义
2015/04/08 面试题
laravel使用redis队列实例讲解
2021/03/23 PHP
怎样写好自我评价呢?
2014/02/16 职场文书
英文演讲稿开场白
2014/08/25 职场文书
机动车交通事故协议书
2015/01/29 职场文书
MySQL高速缓存启动方法及参数详解(query_cache_size)
2021/07/01 MySQL
SSM VUE Axios详解
2021/10/05 Vue.js
忘记Grafana不要紧2种Grafana重置admin密码方法详细步骤
2022/04/07 Servers
Spring JPA 增加字段执行异常问题及解决
2022/06/10 Java/Android