pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python发送邮件接收邮件示例分享
Jan 21 Python
用Python实现一个简单的能够上传下载的HTTP服务器
May 05 Python
Linux 发邮件磁盘空间监控(python)
Apr 23 Python
django模型层(model)进行建表、查询与删除的基础教程
Nov 21 Python
使用python绘制3维正态分布图的方法
Dec 29 Python
python交换两个变量的值方法
Jan 12 Python
python django下载大的csv文件实现方法分析
Jul 19 Python
基于python实现自动化办公学习笔记(CSV、word、Excel、PPT)
Aug 06 Python
Python中sorted()排序与字母大小写的问题
Jan 14 Python
Python turtle画图库&&画姓名实例
Jan 19 Python
python如何删除列为空的行
Jul 17 Python
基于Python实现将列表数据生成折线图
Mar 23 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
如何开始收听短波广播
2021/03/01 无线电
基于php中使用excel的简单介绍
2013/08/02 PHP
PHP计算指定日期所在周的开始和结束日期的方法
2015/03/24 PHP
php实现构建排除当前元素的乘积数组方法
2018/10/06 PHP
JavaScript基本对象
2007/01/11 Javascript
Confirmer JQuery确认对话框组件
2010/06/09 Javascript
基于jQuery的为attr添加id title等效果的实现代码
2011/04/20 Javascript
js+cookies实现悬浮购物车的方法
2015/05/25 Javascript
JavaScript File API文件上传预览
2016/02/02 Javascript
JQuery.validate在ie8下不支持的快速解决方法
2016/05/18 Javascript
jQuery实现的超链接提示效果示例【附demo源码下载】
2016/09/09 Javascript
基于touch.js手势库+zepto.js插件开发图片查看器(滑动、缩放、双击缩放)
2016/11/17 Javascript
jQuery EasyUI 获取tabs的实例解析
2016/12/06 Javascript
React中如何引入Angular组件详解
2018/08/09 Javascript
Vue CL3 配置路径别名详解
2019/05/30 Javascript
初试vue-cli使用HBuilderx打包app的坑
2019/07/17 Javascript
Vue 实现复制功能,不需要任何结构内容直接复制方式
2019/11/09 Javascript
javascript严格模式详解(含严格模式与非严格模式的区别)
2019/11/12 Javascript
Javascript原型链及instanceof原理详解
2020/05/25 Javascript
Vue中keep-alive组件的深入理解
2020/08/23 Javascript
python中redis查看剩余过期时间及用正则通配符批量删除key的方法
2018/07/30 Python
详解用python实现基本的学生管理系统(文件存储版)(python3)
2019/04/25 Python
python实现在多维数组中挑选符合条件的全部元素
2019/11/26 Python
Django中Q查询及Q()对象 F查询及F()对象用法
2020/07/09 Python
CSS3控制HTML元素动画效果
2014/02/08 HTML / CSS
大学生标准推荐信范文
2013/11/25 职场文书
省级四好少年事迹材料
2014/01/25 职场文书
迎八一活动主题
2014/01/31 职场文书
模范教师材料大全
2014/12/16 职场文书
经理岗位职责范本
2015/04/15 职场文书
小学生暑假生活总结
2015/07/13 职场文书
学校运动会简讯
2015/07/20 职场文书
python scrapy简单模拟登录的代码分析
2021/07/21 Python
SQL CASE 表达式的具体使用
2022/03/21 SQL Server
剑指Offer之Java算法习题精讲二叉树专项训练
2022/03/21 Java/Android
Python实现提取PDF简历信息并存入Excel
2022/04/02 Python