对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中MySQLdb和torndb模块对MySQL的断连问题处理
Nov 09 Python
Python打造出适合自己的定制化Eclipse IDE
Mar 02 Python
Python实现一个转存纯真IP数据库的脚本分享
May 21 Python
Python多进程库multiprocessing中进程池Pool类的使用详解
Nov 24 Python
python实现定时自动备份文件到其他主机的实例代码
Feb 23 Python
解决nohup重定向python输出到文件不成功的问题
May 11 Python
解决sublime+python3无法输出中文的问题
Dec 12 Python
Python实例方法、类方法、静态方法的区别与作用详解
Mar 25 Python
python实现网站微信登录的示例代码
Sep 18 Python
python使用信号量动态更新配置文件的操作
Apr 01 Python
Python代码需要缩进吗
Jul 01 Python
使用python脚本自动生成K8S-YAML的方法示例
Jul 12 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
PHP与MySQL开发中页面出现乱码的一种解决方法
2007/07/29 PHP
通过PHP的内置函数,通过DES算法对数据加密和解密
2012/06/21 PHP
IIS 7.5 asp Session超时时间设置方法
2017/04/17 PHP
php实现不通过扩展名准确判断文件类型的方法【finfo_file方法与二进制流】
2017/04/18 PHP
总结PHP内存释放以及垃圾回收
2018/03/29 PHP
国外Lightbox v2.03.3 最新版 下载
2007/10/17 Javascript
js 判断所选时间(或者当前时间)是否在某一时间段的实现代码
2015/09/05 Javascript
JS实现网页上随机产生超链接地址的方法
2015/11/09 Javascript
js仿淘宝和百度文库的评分功能
2016/05/15 Javascript
浅谈JavaScript 数据属性和访问器属性
2016/09/01 Javascript
nodejs模块nodemailer基本使用-邮件发送示例(支持附件)
2017/03/28 NodeJs
让div运动起来 js实现缓动效果
2017/07/06 Javascript
vue中keep-alive的用法及问题描述
2018/05/15 Javascript
JavaScript设计模式之职责链模式应用示例
2018/08/07 Javascript
如何实现小程序tab栏下划线动画效果
2019/05/18 Javascript
Nuxt.js实现一个SSR的前端博客的示例代码
2019/09/06 Javascript
Echarts实现多条折线可拖拽效果
2019/12/19 Javascript
JavaScript this使用方法图解
2020/02/04 Javascript
Node.js 在本地生成日志文件的方法
2020/02/07 Javascript
vue+openlayers绘制省市边界线
2020/12/24 Vue.js
Python logging模块学习笔记
2014/05/24 Python
python使用在线API查询IP对应的地理位置信息实例
2014/06/01 Python
Python实现列表转换成字典数据结构的方法
2016/03/11 Python
Python字符串格式化的方法(两种)
2017/09/19 Python
快速解决PyCharm无法引用matplotlib的问题
2018/05/24 Python
python 自动批量打开网页的示例
2019/02/21 Python
Python HTMLTestRunner如何下载生成报告
2020/09/04 Python
html5 实现客户端验证上传文件的大小(简单实例)
2016/05/15 HTML / CSS
英国乐购杂货:Tesco Groceries
2018/11/29 全球购物
最好的商品表达自己:Cafepress
2019/09/04 全球购物
儿科护士自我鉴定
2013/10/14 职场文书
学生检讨书怎么写?
2014/10/10 职场文书
党课主持词大全
2015/06/30 职场文书
2016年艾滋病宣传活动总结
2016/04/01 职场文书
创业计划书之家教中心
2019/09/25 职场文书
SQL CASE 表达式的具体使用
2022/03/21 SQL Server