对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 专题三 字符串的基础知识
Mar 19 Python
详解Python函数可变参数定义及其参数传递方式
Aug 02 Python
Python及PyCharm下载与安装教程
Nov 18 Python
详解Python中的内建函数,可迭代对象,迭代器
Apr 29 Python
Python Web框架之Django框架Form组件用法详解
Aug 16 Python
python 利用turtle模块画出没有角的方格
Nov 23 Python
python关于调用函数外的变量实例
Dec 26 Python
Python while循环使用else语句代码实例
Feb 07 Python
Python连接Oracle之环境配置、实例代码及报错解决方法详解
Feb 11 Python
python中def是做什么的
Jun 10 Python
解决Jupyter-notebook不弹出默认浏览器的问题
Mar 30 Python
Python编写可视化界面的全过程(Python+PyCharm+PyQt)
May 17 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
建立动态的WML站点(三)
2006/10/09 PHP
建立动态的WML站点(二)
2006/10/09 PHP
php下几个常用的去空、分组、调试数组函数
2009/02/22 PHP
javascript 清除输入框中的数据
2009/04/13 Javascript
分享20多个很棒的jQuery 文件上传插件或教程
2011/09/04 Javascript
form表单中去掉默认的enter键提交并绑定js方法实现代码
2013/04/01 Javascript
Bootstrap每天必学之下拉菜单
2015/11/25 Javascript
简单实现的JQuery文本框水印插件
2016/06/14 Javascript
浅谈javascript中的加减时间
2016/07/12 Javascript
JavaScript 用fetch 实现异步下载文件功能
2017/07/21 Javascript
vue-router路由与页面间导航实例解析
2017/11/07 Javascript
JS动画定时器知识总结
2018/03/23 Javascript
js实现延迟加载的几种方法详解
2019/01/19 Javascript
React优化子组件render的使用
2019/05/12 Javascript
基于脚手架创建Vue项目实现步骤详解
2020/08/03 Javascript
[54:53]2014 DOTA2国际邀请赛中国区预选赛 LGD-GAMING VS CIS 第二场
2014/05/23 DOTA
Python基础之函数用法实例详解
2014/09/10 Python
在arcgis使用python脚本进行字段计算时是如何解决中文问题的
2015/10/18 Python
python设计模式大全
2016/06/27 Python
tensorflow构建BP神经网络的方法
2018/03/12 Python
python如何对实例属性进行类型检查
2018/03/20 Python
Python小白必备的8个最常用的内置函数(推荐)
2019/04/03 Python
使用TensorFlow-Slim进行图像分类的实现
2019/12/31 Python
jupyter notebook参数化运行python方式
2020/04/10 Python
CSS3实现曲线阴影和翘边阴影
2016/05/03 HTML / CSS
美国一家专业的太阳镜网上零售商:Solstice太阳镜
2016/07/25 全球购物
T3官网:头发造型工具
2019/12/26 全球购物
工程造价自荐信
2013/10/09 职场文书
法务专员岗位职责
2014/01/02 职场文书
学校地质灾害防治方案
2014/06/10 职场文书
公司总经理岗位职责范本
2014/08/15 职场文书
领导干部“四风”问题批评与自我批评材料
2014/09/24 职场文书
护理见习报告范文
2014/11/03 职场文书
2015年法律事务部工作总结
2015/07/27 职场文书
详解MySQL中timestamp和datetime时区问题导致做DTS遇到的坑
2021/12/06 MySQL
python读取mat文件生成h5文件的实现
2022/07/15 Python