对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取Linux系统下的本机IP地址代码分享
Nov 07 Python
python判断字符串编码的简单实现方法(使用chardet)
Jul 01 Python
python anaconda 安装 环境变量 升级 以及特殊库安装的方法
Jun 21 Python
django中的setting最佳配置小结
Nov 21 Python
django使用LDAP验证的方法示例
Dec 10 Python
带你认识Django
Jan 15 Python
Tensorflow: 从checkpoint文件中读取tensor方式
Feb 10 Python
django使用F方法更新一个对象多个对象字段的实现
Mar 28 Python
django为Form生成的label标签添加class方式
May 20 Python
手把手教你用Django执行原生SQL的方法
Feb 18 Python
matplotlib之pyplot模块之标题(title()和suptitle())
Feb 22 Python
Python基础之常用库常用方法整理
Apr 30 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
索尼SONY ICF-SW7600GR电路分析与改良
2021/03/02 无线电
基于mysql的bbs设计(三)
2006/10/09 PHP
PHP常用的排序和查找算法
2015/08/06 PHP
laravel 操作数据库常用函数的返回值方法
2019/10/11 PHP
JavaScript XML操作 封装类
2009/07/01 Javascript
JavaScript字符串String和Array操作的有趣方法
2012/12/18 Javascript
js实时获取系统当前时间实例代码
2013/06/28 Javascript
JS实现表单中checkbox对勾选中增加边框显示效果
2015/08/21 Javascript
深入浅析JS的数组遍历方法(推荐)
2016/06/15 Javascript
JavaScript 继承详解(六)
2016/10/11 Javascript
浅谈jquery采用attr修改form表单enctype不起作用的问题
2016/11/25 Javascript
js实现百度搜索提示框
2017/02/05 Javascript
使用webpack搭建vue项目实现脚手架功能
2019/03/15 Javascript
微信小程序如何调用图片接口API并居中显示
2019/06/29 Javascript
vue实现设置载入动画和初始化页面动画效果
2019/10/28 Javascript
使用vue3重构拼图游戏的实现示例
2021/01/25 Vue.js
探究Python中isalnum()方法的使用
2015/05/18 Python
利用Python如何生成随机密码
2016/04/20 Python
Python实现视频下载功能
2017/03/14 Python
python 实现tar文件压缩解压的实例详解
2017/08/20 Python
Python 12306抢火车票脚本 Python京东抢手机脚本
2018/02/06 Python
修改python plot折线图的坐标轴刻度方法
2018/12/13 Python
python+selenium实现自动化百度搜索关键词
2019/06/03 Python
如何基于Python批量下载音乐
2019/11/11 Python
jupyter notebook 重装教程
2020/04/16 Python
html5 css3 动态气泡按钮实例演示
2012/12/02 HTML / CSS
英国高级健康和美容产品零售商:Life and Looks
2019/08/01 全球购物
什么是触发器(trigger)? 触发器有什么作用?
2013/09/18 面试题
计算机专业推荐信范文
2013/11/27 职场文书
工会主席事迹材料
2014/06/03 职场文书
2014年食品安全工作总结
2014/12/04 职场文书
员工聘用合同范本
2015/09/21 职场文书
学习社交礼仪心得体会
2016/01/22 职场文书
MySQL数据库压缩版本安装与配置详细教程
2021/05/21 MySQL
Python循环之while无限迭代
2022/04/30 Python
Spring Boot优化后启动速度快到飞起技巧示例
2022/07/23 Java/Android