pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
在Python中操作字典之fromkeys()方法的使用
May 21 Python
解决python xlrd无法读取excel文件的问题
Dec 25 Python
python实现播放音频和录音功能示例代码
Dec 30 Python
Numpy之random函数使用学习
Jan 29 Python
Python Django框架实现应用添加logging日志操作示例
May 17 Python
pycharm快捷键汇总
Feb 14 Python
在python中使用pymysql往mysql数据库中插入(insert)数据实例
Mar 02 Python
python 解决mysql where in 对列表(list,,array)问题
Jun 06 Python
django美化后台django-suit的安装配置操作
Jul 12 Python
Django静态文件加载失败解决方案
Aug 26 Python
python 实现压缩和解压缩的示例
Sep 22 Python
详解matplotlib绘图样式(style)初探
Feb 03 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
24条货真价实的PHP代码优化技巧
2016/07/28 PHP
PHPStrom 新建FTP项目以及在线操作教程
2016/10/16 PHP
解决thinkPHP 5 nginx 部署时,只跳转首页的问题
2019/10/16 PHP
根据分辨率不同,调用不同的css文件
2006/07/07 Javascript
jquery的颜色选择插件实例代码
2008/10/02 Javascript
给Flash加一个超链接(推荐使用透明层)兼容主流浏览器
2013/06/09 Javascript
js实现页面转发功能示例代码
2013/08/05 Javascript
js实现简单登录功能的实例代码
2013/11/09 Javascript
jQuery操作cookie方法实例教程
2014/11/25 Javascript
jquery实现键盘左右翻页特效
2015/04/30 Javascript
js实现登陆遮罩效果的方法
2015/07/28 Javascript
jQuery实现的调整表格行tr上下顺序
2016/01/10 Javascript
jquery mobile开发常见问题分析
2016/01/21 Javascript
JS简单封装的图片无缝滚动效果示例【测试可用】
2017/03/22 Javascript
JavaScript中使用参数个数实现重载功能
2017/09/01 Javascript
Vue基于NUXT的SSR详解
2017/10/24 Javascript
JavaScript编程设计模式之观察者模式(Observer Pattern)实例详解
2017/10/25 Javascript
js获取html页面代码中图片地址的实现代码
2018/03/05 Javascript
详解vue-router 初始化时做了什么
2018/06/11 Javascript
VUE前后端学习tab写法实例
2019/08/06 Javascript
Weex开发之WEEX-EROS开发踩坑(小结)
2019/10/16 Javascript
如何在postman中添加cookie信息步骤解析
2020/06/30 Javascript
[02:53]DOTA2英雄基础教程 山岭巨人小小
2013/12/09 DOTA
[01:10]DOTA2次级职业联赛 - U5战队宣传片
2014/12/01 DOTA
[48:30]LGD vs infamous Supermajor小组赛D组 BO3 第一场 6.3
2018/06/04 DOTA
python3 实现的对象与json相互转换操作示例
2019/08/17 Python
tensorflow入门:TFRecordDataset变长数据的batch读取详解
2020/01/20 Python
tensorflow使用CNN分析mnist手写体数字数据集
2020/06/17 Python
使用HTML和CSS3绘制基本卡通图案的示例分享
2015/11/06 HTML / CSS
HTC VIVE美国官网:VR虚拟现实眼镜
2018/02/13 全球购物
个人简历自我鉴定
2013/10/11 职场文书
学校后勤人员职责
2013/12/27 职场文书
微笑面对生活演讲稿
2014/05/13 职场文书
乡镇党建工作汇报材料
2014/08/14 职场文书
单方离婚协议书范本(2014版)
2014/09/30 职场文书
白酒代理协议书范本
2014/10/26 职场文书