pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python正则表达式匹配ip地址实例
Oct 09 Python
Python SQLite3简介
Feb 22 Python
基于Python List的赋值方法
Jun 23 Python
python tornado微信开发入门代码
Aug 24 Python
python实现flappy bird游戏
Dec 24 Python
Python利用字典破解WIFI密码的方法
Feb 27 Python
使用python socket分发大文件的实现方法
Jul 08 Python
Python实现串口通信(pyserial)过程解析
Sep 25 Python
wxpython绘制圆角窗体
Nov 18 Python
python3 循环读取excel文件并写入json操作
Jul 14 Python
Python同时迭代多个序列的方法
Jul 28 Python
Python如何输出百分比
Jul 31 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
PHP多线程抓取网页实现代码
2010/07/22 PHP
Linux php 中文乱码的快速解决方法
2016/05/13 PHP
PHP获取数组中单列值的方法
2017/06/10 PHP
php双层循环(九九乘法表)
2017/10/23 PHP
Yii框架Session与Cookie使用方法示例
2019/10/14 PHP
RR vs IO BO3 第二场2.13
2021/03/10 DOTA
Exjs 入门篇
2010/04/07 Javascript
PHP 与 js的通信(via ajax,json)
2010/11/16 Javascript
Jquery阻止事件冒泡 event.stopPropagation
2011/12/11 Javascript
javascript搜索框点击文字消失失焦时文本出现
2014/09/18 Javascript
解读Bootstrap v4 sass设计
2016/05/29 Javascript
jquery树形菜单效果的简单实例
2016/06/06 Javascript
AngularJS解决ng界面长表达式(ui-set)的方法分析
2016/11/07 Javascript
详解打造 Vue.js 可复用组件
2017/03/24 Javascript
Nodejs调用WebService的示例代码
2017/09/29 NodeJs
jQuery实现获取选中复选框的值实例详解
2018/06/28 jQuery
vue-rx的初步使用教程
2018/09/21 Javascript
原生JS实现的自动轮播图功能详解
2018/12/28 Javascript
[04:28]2014DOTA2国际邀请赛 采访小兔子LGD挺进钥匙体育馆
2014/07/14 DOTA
Python实现统计英文单词个数及字符串分割代码
2015/05/28 Python
Python3.6简单反射操作示例
2018/06/14 Python
win10下python3.5.2和tensorflow安装环境搭建教程
2018/09/19 Python
解决Django删除migrations文件夹中的文件后出现的异常问题
2019/08/31 Python
python3 mmh3安装及使用方法
2019/10/09 Python
python字典排序的方法
2019/10/12 Python
Pandas 解决dataframe的一列进行向下顺移问题
2019/12/27 Python
HTC VIVE美国官网:VR虚拟现实眼镜
2018/02/13 全球购物
开普敦通行证:Cape Town Pass
2019/07/18 全球购物
Perfume’s Club澳大利亚官网:西班牙领先的在线美容店
2021/02/01 全球购物
购房公证委托书(2014版)
2014/09/12 职场文书
2014年教育工作总结
2014/11/26 职场文书
学历证明样本
2015/06/16 职场文书
创业计划书之美甲店
2019/09/20 职场文书
导游词之京东大峡谷旅游区
2019/10/29 职场文书
mybatis中注解与xml配置的对应关系和对比分析
2021/08/04 Java/Android
python turtle绘图
2022/05/04 Python