pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python random模块(获取随机数)常用方法和使用例子
May 13 Python
Python格式化css文件的方法
Mar 10 Python
利用Python获取操作系统信息实例
Sep 02 Python
Python实现类似比特币的加密货币区块链的创建与交易实例
Mar 20 Python
Python中将dataframe转换为字典的实例
Apr 13 Python
python抽取指定url页面的title方法
May 11 Python
Python网页正文转换语音文件的操作方法
Dec 09 Python
解决django前后端分离csrf验证的问题
Feb 03 Python
详解python调用cmd命令三种方法
Jul 08 Python
Python-Tkinter Text输入内容在界面显示的实例
Jul 12 Python
python 直接赋值和copy的区别详解
Aug 07 Python
Python中os模块的简单使用及重命名操作
Apr 17 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
php mysql索引问题
2008/06/07 PHP
PHP 如何利用phpexcel导入数据库
2013/08/24 PHP
php版微信公众平台实现预约提交后发送email的方法
2016/09/26 PHP
Laravel5.5+ 使用API Resources快速输出自定义JSON方法详解
2020/04/06 PHP
javascript 播放器 控制
2007/01/22 Javascript
JavaScript中的prototype使用说明
2010/04/13 Javascript
javascript之querySelector和querySelectorAll使用说明
2011/10/09 Javascript
15个款优秀的 jQuery 图片特效插件推荐
2011/11/21 Javascript
javascript学习笔记(十三) js闭包介绍(转)
2012/06/20 Javascript
cookie 最近浏览记录(中文escape转码)具体实现
2013/06/08 Javascript
浅析js中取绝对值的2种方法
2013/07/09 Javascript
JavaScript阻止事件冒泡示例分享
2014/12/28 Javascript
jQuery仿Flash上下翻动的中英文导航菜单实例
2015/03/10 Javascript
jQuery结合CSS制作动态的下拉菜单
2015/10/27 Javascript
jQuery实现的动态文字变化输出效果示例【附演示与demo源码下载】
2017/03/24 jQuery
用js屏蔽被http劫持的浮动广告实现方法
2017/08/10 Javascript
开发用到的js封装方法(20种)
2018/10/12 Javascript
vue中实现上传文件给后台实例详解
2019/08/22 Javascript
Vue.js计算机属性computed和methods方法详解
2019/10/12 Javascript
vue中使用腾讯云Im的示例
2020/10/23 Javascript
python中字符串内置函数的用法总结
2018/09/13 Python
pycharm远程开发项目的实现步骤
2019/01/20 Python
8段用于数据清洗Python代码(小结)
2019/10/31 Python
python飞机大战 pygame游戏创建快速入门详解
2019/12/17 Python
python实现横向拼接图片
2020/03/23 Python
TensorFlow的环境配置与安装方法
2021/02/20 Python
美国孕妇装购物网站:Motherhood Maternity
2019/09/22 全球购物
心理咨询专业自荐信
2014/07/07 职场文书
英文演讲稿开场白
2014/08/25 职场文书
群众路线教育实践活动剖析材料
2014/09/30 职场文书
给女朋友道歉的话大全
2015/01/20 职场文书
考察邀请函范文
2015/01/31 职场文书
JavaScript控制台的更多功能
2021/04/28 Javascript
Springboot配置suffix指定mvc视图的后缀方法
2021/07/03 Java/Android
php 文件上传至OSS及删除远程阿里云OSS文件
2021/07/04 PHP
python 闭包函数详细介绍
2022/04/19 Python