对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
JS设计模式之责任链模式实例详解
Feb 03 Python
python调用百度REST API实现语音识别
Aug 30 Python
python调用百度地图WEB服务API获取地点对应坐标值
Jan 16 Python
python调用c++ ctype list传数组或者返回数组的方法
Feb 13 Python
django的分页器Paginator 从django中导入类
Jul 25 Python
Python BeautifulSoup [解决方法] TypeError: list indices must be integers or slices, not str
Aug 07 Python
Python使用scipy模块实现一维卷积运算示例
Sep 05 Python
解析Python3中的Import
Oct 13 Python
Python操作SQLite/MySQL/LMDB数据库的方法
Nov 07 Python
Python 字符串处理特殊空格\xc2\xa0\t\n Non-breaking space
Feb 23 Python
Python读取配置文件(config.ini)以及写入配置文件
Apr 08 Python
Python Pandas数据分析工具用法实例
Nov 05 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
php实现mysql数据库备份类
2008/03/20 PHP
PHP三层结构(上) 简单三层结构
2010/07/04 PHP
PHP开发不能违背的安全规则 过滤用户输入
2011/05/01 PHP
深入PHP curl参数的详解
2013/06/17 PHP
JavaScript创建命名空间的5种写法
2014/06/24 PHP
php简单实现MVC
2015/02/05 PHP
PHP实现的登录,注册及密码修改功能分析
2016/11/25 PHP
php从数据库读取数据,并以json格式返回数据的方法
2018/08/21 PHP
你必须知道的JavaScript 中字符串连接的性能的一些问题
2013/05/07 Javascript
一款jquery特效编写的大度宽屏焦点图切换特效的实例代码
2013/08/05 Javascript
js日期、星座的级联显示代码
2014/01/23 Javascript
编写高性能Javascript代码的N条建议
2015/10/12 Javascript
学JavaScript七大注意事项【必看】
2016/05/04 Javascript
JS弹出窗口的运用与技巧大全
2016/11/01 Javascript
详解vue.js全局组件和局部组件
2017/04/10 Javascript
promise处理多个相互依赖的异步请求(实例讲解)
2017/08/03 Javascript
js链表操作(实例讲解)
2017/08/29 Javascript
不使用 JS 匿名函数理由
2017/11/17 Javascript
vue 项目如何引入微信sdk接口的方法
2017/12/18 Javascript
Node.js log4js日志管理详解
2018/07/31 Javascript
boostrap模态框二次弹出清空原有内容的方法
2018/08/10 Javascript
python3序列化与反序列化用法实例
2015/05/26 Python
详解Django中的form库的使用
2015/07/18 Python
python3判断url链接是否为404的方法
2018/08/10 Python
Python如何实现定时器功能
2020/05/28 Python
如何用Python绘制3D柱形图
2020/09/16 Python
基于Python采集爬取微信公众号历史数据
2020/11/27 Python
css3与html5实现响应式导航菜单(导航栏)效果分享
2014/02/12 HTML / CSS
HTML5之SVG 2D入门5—颜色的表示及定义方式
2013/01/30 HTML / CSS
HTML5新增的8类INPUT输入类型介绍
2015/07/06 HTML / CSS
Zatchels官网:英国剑桥包品牌
2021/01/12 全球购物
简历的自我评价
2014/02/03 职场文书
农村葬礼主持词
2014/03/31 职场文书
环境科学专业求职信
2014/08/04 职场文书
2019公司管理制度
2019/04/19 职场文书
详解Nginx启动失败的几种错误处理
2021/04/01 Servers