pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Flask SQLAlchemy一对一,一对多的使用方法实践
Feb 10 Python
Django中使用group_by的方法
May 26 Python
Python中的urllib模块使用详解
Jul 07 Python
python 实现上传图片并预览的3种方法(推荐)
Jul 14 Python
pandas groupby 分组取每组的前几行记录方法
Apr 20 Python
python识别图像并提取文字的实现方法
Jun 28 Python
Python 使用元类type创建类对象常见应用详解
Oct 17 Python
python定义类self用法实例解析
Jan 22 Python
解决pytorch-yolov3 train 报错的问题
Feb 18 Python
详解Python高阶函数
Aug 15 Python
10个python爬虫入门基础代码实例 + 1个简单的python爬虫完整实例
Dec 16 Python
virtualenv隔离Python环境的问题解析
Jun 21 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
一个比较简单的PHP 分页分组类
2009/12/10 PHP
php分页查询mysql结果的base64处理方法示例
2017/05/18 PHP
PHP4和PHP5版本下解析XML文档的操作方法实例分析
2017/05/20 PHP
PHP简单实现正则匹配省市区的方法
2018/04/13 PHP
YII框架学习笔记之命名空间、操作响应与视图操作示例
2019/04/30 PHP
Jquery 学习笔记(一)
2009/10/13 Javascript
JQueryiframe页面操作父页面中的元素与方法(实例讲解)
2013/11/19 Javascript
javascript实现数字+字母验证码的简单实例
2014/02/10 Javascript
JS实现单行文字不间断向上滚动的方法
2015/01/29 Javascript
jQuery实现仿腾讯视频列表分页效果的方法
2015/08/07 Javascript
jQuery实现选项卡切换效果简单演示
2015/12/09 Javascript
使用Bootstrap框架制作查询页面的界面实例代码
2016/05/27 Javascript
JavaScript实现复制文章自动添加版权
2016/08/02 Javascript
jQuery阻止移动端遮罩层后页面滚动
2017/03/15 Javascript
Vuejs入门教程之Vue生命周期,数据,手动挂载,指令,过滤器
2017/04/19 Javascript
JavaScript正则表达式的贪婪匹配和非贪婪匹配
2017/09/05 Javascript
vue2.0.js的多级联动选择器实现方法
2018/02/09 Javascript
使用JavaScript生成罗马字符的实例代码
2018/06/08 Javascript
Vue2.x Todo之自定义指令实现自动聚焦的方法
2019/01/08 Javascript
[01:45:05]VGJ.T vs Newbee Supermajor 败者组 BO3 第二场 6.6
2018/06/07 DOTA
python删除过期文件的方法
2015/05/29 Python
详解python之简单主机批量管理工具
2017/01/27 Python
python3获取当前文件的上一级目录实例
2018/04/26 Python
python读取excel指定列数据并写入到新的excel方法
2018/07/10 Python
对Python函数设计规范详解
2019/07/19 Python
使用TensorFlow实现简单线性回归模型
2019/07/19 Python
Python实现变声器功能(萝莉音御姐音)
2019/12/05 Python
python实现飞行棋游戏
2020/02/05 Python
HTML5 canvas基本绘图之文字渲染
2016/06/27 HTML / CSS
俄罗斯最大的消费电子连锁零售商:Mvideo
2017/06/25 全球购物
印度化妆品购物网站:Nykaa
2018/07/22 全球购物
C/C++ 笔试、面试题目大汇总
2015/11/21 面试题
实习自我鉴定模板
2013/09/28 职场文书
招商经理岗位职责
2013/11/16 职场文书
优秀大学生求职自荐信范文
2014/04/19 职场文书
企业晚会策划方案
2014/05/29 职场文书