pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
基于Python的文件类型和字符串详解
Dec 21 Python
Python 中Pickle库的使用详解
Feb 24 Python
浅谈python之新式类
Aug 12 Python
python得到qq句柄,并显示在前台的方法
Oct 14 Python
在unittest中使用 logging 模块记录测试数据的方法
Nov 30 Python
python获取微信企业号打卡数据并生成windows计划任务
Apr 30 Python
python导入坐标点的具体操作
May 10 Python
Python实现使用dir获取类的方法列表
Dec 24 Python
python 轮询执行某函数的2种方式
May 03 Python
Python通过fnmatch模块实现文件名匹配
Sep 30 Python
详解java调用python的几种用法(看这篇就够了)
Dec 10 Python
Python爬虫入门教程02之笔趣阁小说爬取
Jan 24 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
用php实现让页面只能被百度gogole蜘蛛访问的方法
2009/12/29 PHP
php生成静态文件的多种方法分享
2012/07/17 PHP
PHP遍历某个目录下的所有文件和子文件夹的实现代码
2013/06/28 PHP
浅谈PDO的rowCount函数
2015/06/18 PHP
PHP 枚举类型的管理与设计知识点总结
2020/02/13 PHP
JavaScript 计算当天是本年本月的第几周
2009/03/22 Javascript
PHP abstract与interface之间的区别
2013/11/11 Javascript
JS中产生标识符方式的演变
2015/06/12 Javascript
黑帽seo劫持程序,js劫持搜索引擎代码
2015/09/15 Javascript
js实现图片无缝滚动特效
2020/03/19 Javascript
JS简单去除数组中重复项的方法
2016/09/13 Javascript
原生js实现简单的模态框示例
2017/09/08 Javascript
nodejs超出最大的调用栈错误问题
2017/12/27 NodeJs
nodejs爬虫初试superagent和cheerio
2018/03/05 NodeJs
微信小程序scroll-x失效的完美解决方法
2018/07/18 Javascript
vue中格式化时间过滤器代码实例
2019/04/17 Javascript
使用JS判断页面是首次被加载还是刷新
2019/05/26 Javascript
[47:08]OG vs INfamous 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/17 DOTA
python定向爬取淘宝商品价格
2018/02/27 Python
python执行系统命令后获取返回值的几种方式集合
2018/05/12 Python
Django中celery执行任务结果的保存方法
2019/07/12 Python
通过实例了解Python str()和repr()的区别
2020/01/17 Python
python GUI库图形界面开发之PyQt5浏览器控件QWebEngineView详细使用方法
2020/02/26 Python
Softmax函数原理及Python实现过程解析
2020/05/22 Python
巧用CSS3 border实现图片遮罩效果代码
2012/04/09 HTML / CSS
详解HTML5常用的语义化标签
2019/09/27 HTML / CSS
大学生军训自我评价分享
2013/11/09 职场文书
酒店总经理欢迎词
2014/01/08 职场文书
《梅兰芳学艺》教学反思
2014/02/24 职场文书
股份合作协议书范本
2014/04/14 职场文书
超市商业计划书
2014/05/04 职场文书
2014年乡镇妇联工作总结
2014/12/02 职场文书
zabbix监控mysql的实例方法
2021/06/02 MySQL
React forwardRef的使用方法及注意点
2021/06/13 Javascript
python创建字典及相关管理操作
2022/04/13 Python
Alexa停服!网站排名将何去何从?目前还没有替代品。
2022/04/15 杂记