对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python对指定目录下文件进行批量重命名的方法
Apr 18 Python
python3读取MySQL-Front的MYSQL密码
May 03 Python
pandas 条件搜索返回列表的方法
Oct 30 Python
Python装饰器用法实例分析
Jan 14 Python
利用pandas合并多个excel的方法示例
Oct 10 Python
Python编程快速上手——Excel表格创建乘法表案例分析
Feb 28 Python
python中有函数重载吗
May 28 Python
在keras 中获取张量 tensor 的维度大小实例
Jun 10 Python
六种酷炫Python运行进度条效果的实现代码
Jul 17 Python
python实现三壶谜题的示例详解
Nov 02 Python
Flask response响应的具体使用
Jul 15 Python
python运算符之与用户交互
Apr 13 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
php 高性能书写
2010/12/11 PHP
[原创]php常用字符串输出方法分析(echo,print,printf及sprintf)
2016/07/09 PHP
php求斐波那契数的两种实现方式【递归与递推】
2019/09/09 PHP
php使用curl伪造浏览器访问操作示例
2019/09/30 PHP
hover的用法及live的用法介绍(鼠标悬停效果)
2013/03/29 Javascript
AngularJS控制器继承自另一控制器
2016/05/09 Javascript
jQuery Tags Input Plugin(添加/删除标签插件)详解
2016/06/20 Javascript
Bootstrap风格的WPF样式
2016/12/07 Javascript
vue检测对象和数组的变化分析
2018/06/30 Javascript
浅析JS中什么是自定义react数据验证组件
2018/10/19 Javascript
用VueJS写一个Chrome浏览器插件的实现方法
2019/02/27 Javascript
ES6知识点整理之函数对象参数默认值及其解构应用示例
2019/04/17 Javascript
Python中除法使用的注意事项
2014/08/21 Python
python基于windows平台锁定键盘输入的方法
2015/03/05 Python
Python 自动化表单提交实例代码
2017/06/08 Python
python 中的int()函数怎么用
2017/10/17 Python
python3.6.3安装图文教程 TensorFlow安装配置方法
2020/06/24 Python
Python 经典面试题 21 道【不可错过】
2018/09/21 Python
Python基本数据结构之字典类型dict用法分析
2019/06/08 Python
Python实现封装打包自己写的代码,被python import
2020/07/12 Python
Python 使用 PyQt5 开发的关机小工具分享
2020/07/16 Python
Pyinstaller打包Scrapy项目的实现步骤
2020/09/22 Python
详解Django ORM引发的数据库N+1性能问题
2020/10/12 Python
celery在python爬虫中定时操作实例讲解
2020/11/27 Python
css3 中实现炫酷的loading效果
2019/04/26 HTML / CSS
浅谈pc和移动端的响应式的使用
2019/01/03 HTML / CSS
html5菜单折纸效果
2014/04/22 HTML / CSS
加拿大便宜的隐形眼镜商店:Clearly
2016/09/15 全球购物
学生打架检讨书大全
2014/01/23 职场文书
公司副总经理任命书
2014/06/05 职场文书
国际贸易本科毕业生求职信
2014/09/26 职场文书
入党转正申请报告
2015/05/15 职场文书
小学语文继续教育研修日志
2015/11/13 职场文书
电力企业职工培训心得体会
2016/01/11 职场文书
年终奖金发放管理制度,中小企业适用,拿去救急吧!
2019/07/12 职场文书
MySQL中varchar和char类型的区别
2021/11/17 MySQL