pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
python Django连接MySQL数据库做增删改查
Nov 07 Python
Python 正则表达式实现计算器功能
Apr 29 Python
Python学习小技巧总结
Jun 10 Python
Tensorflow 同时载入多个模型的实例讲解
Jul 27 Python
python 类的继承 实例方法.静态方法.类方法的代码解析
Aug 23 Python
Python高级编程之消息队列(Queue)与进程池(Pool)实例详解
Nov 01 Python
TensorFlow索引与切片的实现方法
Nov 20 Python
Python使用Pyqt5实现简易浏览器(最新版本测试过)
Apr 27 Python
python下对hsv颜色空间进行量化操作
Jun 04 Python
Python实现Canny及Hough算法代码实例解析
Aug 06 Python
Python json解析库jsonpath原理及使用示例
Nov 25 Python
Matplotlib绘制混淆矩阵的实现
May 27 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
收音机指标测试方法及仪器
2021/03/01 无线电
PHP与SQL注入攻击[二]
2007/04/17 PHP
2014过年倒计时示例
2014/01/31 PHP
ThinkPHP框架表单验证操作方法
2017/07/19 PHP
PHP实现类似题库抽题效果
2018/08/16 PHP
代码精简的可以实现元素圆角的js函数
2007/07/21 Javascript
深入Javascript函数、递归与闭包(执行环境、变量对象与作用域链)使用详解
2013/05/08 Javascript
JavaScript中prototype为对象添加属性的误区介绍
2013/10/15 Javascript
js判断当页面无法回退时关闭网页否则就history.go(-1)
2014/08/07 Javascript
NodeJs基本语法和类型
2015/02/13 NodeJs
CSS中position属性之fixed实现div居中
2015/12/14 Javascript
AngularJS使用ngMessages进行表单验证
2015/12/27 Javascript
jQuery实现智能判断固定导航条或侧边栏的方法
2016/09/04 Javascript
Javascript 实现简单计算器实例代码
2016/10/23 Javascript
Angular.js中ng-if、ng-show和ng-hide的区别介绍
2017/01/20 Javascript
js 数据存储和DOM编程
2017/02/09 Javascript
Nodejs--post的公式详解
2017/04/29 NodeJs
ES6 javascript的异步操作实例详解
2017/10/30 Javascript
在react-router4中进行代码拆分的方法(基于webpack)
2018/03/08 Javascript
如何使node也支持从url加载一个module详解
2018/06/05 Javascript
Node.js中读取TXT文件内容fs.readFile()用法
2018/10/10 Javascript
ES6中的Javascript解构的实现
2020/10/30 Javascript
js实现验证码干扰(静态)
2021/02/22 Javascript
简单理解Python中基于生成器的状态机
2015/04/13 Python
python方法生成txt标签文件的实例代码
2018/05/10 Python
Python实现的tcp端口检测操作示例
2018/07/24 Python
Python Flask上下文管理机制实例解析
2020/03/16 Python
Django 用户登陆访问限制实例 @login_required
2020/05/13 Python
一款利用css3的鼠标经过动画显示详情特效的实例教程
2014/12/29 HTML / CSS
深入浅析HTML5中的SVG
2015/11/27 HTML / CSS
总结表彰大会主持词
2014/03/26 职场文书
祖国在我心中演讲稿400字
2014/05/04 职场文书
个人学习党的群众路线教育实践活动心得体会
2014/11/05 职场文书
2014年青年志愿者工作总结
2014/12/09 职场文书
2016党员三严三实心得体会
2016/01/15 职场文书
少儿励志名言(80句)
2019/08/14 职场文书