pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中的魔法方法深入理解
Jul 09 Python
Python中实现从目录中过滤出指定文件类型的文件
Feb 02 Python
python数据清洗系列之字符串处理详解
Feb 12 Python
Python中字符串格式化str.format的详细介绍
Feb 17 Python
python如何通过实例方法名字调用方法
Mar 21 Python
Django objects的查询结果转化为json的三种方式的方法
Nov 07 Python
python调试神器PySnooper的使用
Jul 03 Python
python中append实例用法总结
Jul 30 Python
python3 字符串知识点学习笔记
Feb 08 Python
Keras Convolution1D与Convolution2D区别说明
May 22 Python
python删除文件、清空目录的实现方法
Sep 23 Python
python rsa-oaep加密的示例代码
Sep 23 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
在PHP里得到前天和昨天的日期的代码
2007/08/16 PHP
如何使用PHP实现javascript的escape和unescape函数
2013/06/29 PHP
php计算两个整数的最大公约数常用算法小结
2015/03/05 PHP
php 指定范围内多个随机数代码实例
2016/07/18 PHP
Yii2.0中的COOKIE和SESSION用法
2016/08/12 PHP
php PDO异常处理详解
2016/11/20 PHP
JavaScript实际应用:innerHTMl和确认提示的使用
2006/06/22 Javascript
Firefox 无法获取cssRules 的解决办法
2006/10/11 Javascript
JavaScript修改css样式style
2008/04/15 Javascript
JQUERY 设置SELECT选中项代码
2014/02/07 Javascript
jQuery 删除或是清空某个HTML元素示例
2014/08/04 Javascript
数据分析软件之FineReport教程:[5]参数界面JS(全)
2015/08/13 Javascript
js实现简单的联动菜单效果
2015/08/19 Javascript
jQuery 选择同时包含两个class的元素的实现方法
2016/06/01 Javascript
使用vue实现点击按钮滑出面板的实现代码
2017/01/10 Javascript
BootStrap Fileinput上传插件使用实例代码
2017/07/28 Javascript
webpack构建的详细流程探底
2018/01/08 Javascript
JavaScript代码实现txt文件的上传预览功能
2018/03/27 Javascript
Moment.js实现多个同时倒计时
2019/08/26 Javascript
layui表格内放置图片,并点击放大的实例
2019/09/10 Javascript
Vue使用Ref跨层级获取组件的步骤
2021/01/25 Vue.js
详细解析Python中的变量的数据类型
2015/05/13 Python
Python爬虫包BeautifulSoup异常处理(二)
2018/06/17 Python
python寻找list中最大值、最小值并返回其所在位置的方法
2018/06/27 Python
对Python 内建函数和保留字详解
2018/10/15 Python
Python数据持久化存储实现方法分析
2019/12/21 Python
行政总监岗位职责
2013/12/05 职场文书
大三学生入党思想汇报
2014/01/02 职场文书
小学红领巾中秋节广播稿
2014/01/13 职场文书
项目考察欢迎辞
2014/01/17 职场文书
小学新教师培训方案
2014/02/03 职场文书
光棍节联谊晚会活动策划书
2014/10/10 职场文书
2014年个人工作总结模板
2014/12/15 职场文书
2015年宣传工作总结
2015/04/08 职场文书
新年寄语2016
2015/08/17 职场文书
原生CSS实现文字无限轮播的通用方法
2021/03/30 HTML / CSS