pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python中sleep函数用法实例分析
Apr 29 Python
深入理解Python中命名空间的查找规则LEGB
Aug 06 Python
Python基于pygame实现的font游戏字体(附源码)
Nov 11 Python
sublime text 3配置使用python操作方法
Jun 11 Python
Ubuntu 下 vim 搭建python 环境 配置
Jun 12 Python
Python使用matplotlib和pandas实现的画图操作【经典示例】
Jun 13 Python
python抓取网页内容并进行语音播报的方法
Dec 24 Python
python实现远程控制电脑
May 23 Python
在django中实现页面倒数几秒后自动跳转的例子
Aug 16 Python
python turtle 绘制太极图的实例
Dec 18 Python
python模拟实现分发扑克牌
Apr 22 Python
django正续或者倒序查库实例
May 19 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
php5.3 注意事项说明
2013/07/01 PHP
php实现用于验证所有类型的信用卡类
2015/03/24 PHP
解决laravel groupBy 对查询结果进行分组出现的问题
2019/10/09 PHP
javascript 写的一个简单的timer
2009/07/30 Javascript
学习JS面向对象成果 借国庆发布个最新作品与大家交流
2009/10/03 Javascript
深入理解javascript中defer的作用
2013/12/11 Javascript
JavaScript中的逻辑判断符&&、||与!介绍
2014/12/31 Javascript
Angularjs中如何使用filterFilter函数过滤
2016/02/06 Javascript
Jquery EasyUI实现treegrid上显示checkbox并取选定值的方法
2016/04/29 Javascript
Extjs让combobox写起来简洁又漂亮
2017/01/05 Javascript
JS实现的添加弹出层并完成锁屏操作示例
2017/04/07 Javascript
Vue.js 的移动端组件库mint-ui实现无限滚动加载更多的方法
2017/12/23 Javascript
angularJs复选框checkbox选中进行ng-show显示隐藏的方法
2018/10/08 Javascript
轻松解决JavaScript定时器越走越快的问题
2019/05/13 Javascript
如何用原生js写一个弹窗消息提醒插件
2019/05/24 Javascript
layer更改皮肤的实现方法
2019/09/11 Javascript
浅谈vue-router路由切换 组件重用挖下的坑
2019/11/01 Javascript
JS实现拖拽元素时与另一元素碰撞检测
2020/08/27 Javascript
[33:42]LGD vs OG 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
使用Python实现一个简单的项目监控
2015/03/31 Python
Python IDE PyCharm的基本快捷键和配置简介
2015/11/04 Python
Python使用文件锁实现进程间同步功能【基于fcntl模块】
2017/10/16 Python
python正向最大匹配分词和逆向最大匹配分词的实例
2018/11/14 Python
Django模板Templates使用方法详解
2019/07/19 Python
css3实现动画的三种方式
2020/08/24 HTML / CSS
意大利会呼吸的鞋:Geox健乐士
2017/02/12 全球购物
德国传统玻璃制造商:Cristalica
2018/04/23 全球购物
比较一下entity bean和session bean
2013/12/27 面试题
计算机专业毕业生求职信
2014/04/30 职场文书
机械工程师岗位职责
2014/06/16 职场文书
2014领导班子专题民主生活会对照检查材料思想汇报
2014/09/23 职场文书
2014年机关作风建设工作总结
2014/10/23 职场文书
体育部部长竞选稿
2015/11/21 职场文书
Python如何利用正则表达式爬取网页信息及图片
2021/04/17 Python
JS Canvas接口和动画效果大全
2021/04/29 Javascript
pt-archiver 主键自增
2022/04/26 MySQL