对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
让Python代码更快运行的5种方法
Jun 21 Python
python获取元素在数组中索引号的方法
Jul 15 Python
Python实现身份证号码解析
Sep 01 Python
Python机器学习logistic回归代码解析
Jan 17 Python
理论讲解python多进程并发编程
Feb 09 Python
python 获取utc时间转化为本地时间的方法
Dec 31 Python
Python获取当前脚本文件夹(Script)的绝对路径方法代码
Aug 27 Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 Python
keras:model.compile损失函数的用法
Jul 01 Python
利用PyQt5+Matplotlib 绘制静态/动态图的实现代码
Jul 13 Python
python使用matplotlib绘制折线图的示例代码
Sep 22 Python
Python利用Turtle绘制哆啦A梦和小猪佩奇
Apr 04 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
PHP中的串行化变量和序列化对象
2006/09/05 PHP
Thinkphp框架开发移动端接口(2)
2016/08/18 PHP
PHP 多任务秒级定时器的实现方法
2018/05/13 PHP
控制打印时页眉角的代码
2007/02/08 Javascript
JavaScript 保存数组到Cookie的代码
2010/04/14 Javascript
javascript中的注释使用与注意事项小结
2011/09/20 Javascript
使用UglifyJS合并/压缩JavaScript的方法
2012/03/07 Javascript
解决IE6的PNG透明JS插件使用介绍
2013/04/17 Javascript
如何使用jquery easyui创建标签组件
2015/11/18 Javascript
jQuery在ie6下无法设置select选中的解决方法详解
2016/09/20 Javascript
jquery中$.fn和图片滚动效果实现的必备知识总结
2017/04/21 jQuery
详解vue.js2.0父组件点击触发子组件方法
2017/05/10 Javascript
Javascript继承机制详解
2017/05/30 Javascript
在angular 6中使用 less 的实例代码
2018/05/13 Javascript
详解jQuery-each()方法
2019/03/13 jQuery
VUE 实现复制内容到剪贴板的两种方法
2019/04/24 Javascript
vue项目中mock.js的使用及基本用法
2019/05/22 Javascript
vue分页器组件编写方法详解
2019/06/28 Javascript
Vue 设置axios请求格式为form-data的操作步骤
2019/10/29 Javascript
详解Python2.x中对Unicode编码的使用
2015/04/03 Python
详解Python中的join()函数的用法
2015/04/07 Python
使用Python脚本来获取Cisco设备信息的示例
2015/05/04 Python
python 接口_从协议到抽象基类详解
2017/08/24 Python
使用python根据端口号关闭进程的方法
2018/11/06 Python
在python中利用KNN实现对iris进行分类的方法
2018/12/11 Python
使用CSS3制作饼状旋转载入效果的实例
2015/06/23 HTML / CSS
html5实现的便签特效(实战分享)
2013/11/29 HTML / CSS
HTML5 video标签(播放器)学习笔记(一):使用入门
2015/04/24 HTML / CSS
捷克浴室和厨房设备购物网站:SIKO
2018/08/11 全球购物
俄罗斯厨房产品购物网站:COOK HOUSE
2021/03/15 全球购物
《再见了,亲人》教学反思
2014/02/26 职场文书
元宵节晚会主持人串词
2014/03/25 职场文书
党员检讨书
2014/10/13 职场文书
2014小学数学教研组工作总结
2014/12/06 职场文书
Python数据分析入门之数据读取与存储
2021/05/13 Python
sql server 累计求和实现代码
2022/02/28 SQL Server