对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
解决新django中的path不能使用正则表达式的问题
Dec 18 Python
对python内置map和six.moves.map的区别详解
Dec 19 Python
python从子线程中获得返回值的方法
Jan 30 Python
python 基于TCP协议的套接字编程详解
Jun 29 Python
简单了解python关系(比较)运算符
Jul 08 Python
详解Django模版中加载静态文件配置方法
Jul 21 Python
使用python matplotlib 画图导入到word中如何保证分辨率
Apr 16 Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 Python
python不同系统中打开方法
Jun 23 Python
在Keras中CNN联合LSTM进行分类实例
Jun 29 Python
pytorch使用horovod多gpu训练的实现
Sep 09 Python
Python利用Pillow(PIL)库实现验证码图片的全过程
Oct 04 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
PHP 动态随机生成验证码类代码
2010/04/09 PHP
基于python发送邮件的乱码问题的解决办法
2013/04/25 PHP
基于php下载文件的详解
2013/06/02 PHP
jQuery中:button选择器用法实例
2015/01/04 Javascript
PHP+MySQL+jQuery随意拖动层并即时保存拖动位置实例讲解
2015/10/09 Javascript
js命名空间写法示例
2015/12/18 Javascript
Bootstrap下拉菜单样式
2017/02/07 Javascript
写jQuery插件时的注意点
2017/02/20 Javascript
vue v-on监听事件详解
2017/05/17 Javascript
JS滚动到指定位置导航栏固定顶部
2017/07/03 Javascript
angularJS的radio实现单项二选一的使用方法
2018/02/28 Javascript
JS实现关闭小广告特效
2021/01/29 Javascript
Vue组件模板的几种书写形式(3种)
2020/02/19 Javascript
javascript实现固定侧边栏
2021/02/09 Javascript
[01:04:32]DOTA2-DPC中国联赛 正赛 Aster vs LBZS BO3 第二场 2月23日
2021/03/11 DOTA
python设置检查点简单实现代码
2014/07/01 Python
Python将DataFrame的某一列作为index的方法
2018/04/08 Python
10 行Python 代码实现 AI 目标检测技术【推荐】
2019/06/14 Python
django实现用户注册实例讲解
2019/10/30 Python
Python通过队列来实现进程间通信的示例
2020/10/14 Python
python Zmail模块简介与使用示例
2020/12/19 Python
CSS3 渐变(Gradients)之CSS3 线性渐变
2016/07/08 HTML / CSS
Urban Outfitters英国官网:美国平价服饰品牌
2016/11/25 全球购物
意大利高端时尚买手店:Stefania Mode
2018/03/01 全球购物
Myprotein芬兰官网:欧洲第一运动营养品牌
2019/05/05 全球购物
物流管理专业应届生求职信
2013/11/21 职场文书
采购部主管岗位职责
2014/01/01 职场文书
操行评语大全
2014/04/30 职场文书
超市创意活动方案
2014/08/15 职场文书
优秀纪检干部材料
2014/08/27 职场文书
领导班子整改措施
2014/10/24 职场文书
学生违反校规检讨书
2014/10/28 职场文书
2015年会计工作总结范文
2015/05/26 职场文书
会计主管竞聘书
2015/09/15 职场文书
送给火锅店的创意营销方案!
2019/07/08 职场文书
Python3.8官网文档之类的基础语法阅读
2021/09/04 Python