对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中正则表达式详解
May 17 Python
用生成器来改写直接返回列表的函数方法
May 25 Python
Python实现购物车功能的方法分析
Nov 10 Python
Win7下Python与Tensorflow-CPU版开发环境的安装与配置过程
Jan 04 Python
利用Anaconda简单安装scrapy框架的方法
Jun 13 Python
三步实现Django Paginator分页的方法
Jun 11 Python
Python图片的横坐标汉字实例
Dec 04 Python
Pycharm使用远程linux服务器conda/python环境在本地运行的方法(图解))
Dec 09 Python
python为Django项目上的每个应用程序创建不同的自定义404页面(最佳答案)
Mar 09 Python
用 Python 制作地球仪的方法
Apr 24 Python
详解scrapy内置中间件的顺序
Sep 28 Python
Python爬虫数据的分类及json数据使用小结
Mar 29 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
PHP用mysql数据库存储session的代码
2010/03/05 PHP
PHP中mb_convert_encoding与iconv函数的深入解析
2013/06/21 PHP
Smarty变量调节器失效的解决办法
2014/08/20 PHP
一个完整的php文件上传类实例讲解
2015/10/27 PHP
php while循环控制的简单实例
2016/05/30 PHP
PHP-X系列教程之内置函数的使用示例
2017/10/16 PHP
jQuery 源代码显示控件 (Ajax加载方式).
2009/05/18 Javascript
extjs 列表框(multiselect)的动态添加列表项的方法
2009/07/31 Javascript
Array.prototype.slice 使用扩展
2010/06/09 Javascript
js不能获取隐藏的div的宽度只能先显示后获取
2014/09/04 Javascript
jQuery实现菜单式图片滑动切换
2015/03/14 Javascript
JavaScript中实现无缝滚动、分享到侧边栏实例代码
2016/04/06 Javascript
JS 对java返回的json格式的数据处理方法
2016/12/05 Javascript
JS查找字符串中出现最多的字符及个数统计
2017/02/04 Javascript
thinkjs之页面跳转同步异步操作
2017/02/05 Javascript
原生JS改变透明度实现轮播效果
2017/03/24 Javascript
JavaScript实现图片无缝滚动效果
2017/07/07 Javascript
Vue学习笔记进阶篇之vue-router安装及使用方法
2017/07/19 Javascript
React 无状态组件(Stateless Component) 与高阶组件
2018/08/14 Javascript
JavaScript 高性能数组去重的方法
2018/09/20 Javascript
微信小程序入口场景的问题集合与相关解决方法
2019/06/26 Javascript
Python获取任意xml节点值的方法
2015/05/05 Python
python使用nntp读取新闻组内容的方法
2015/05/08 Python
pandas 快速处理 date_time 日期格式方法
2018/11/12 Python
Django restframework 源码分析之认证详解
2019/02/22 Python
PYTHON EVAL的用法及注意事项解析
2019/09/06 Python
Python实现剪刀石头布小游戏(与电脑对战)
2019/12/31 Python
Pandas替换及部分替换(replace)实现流程详解
2020/10/12 Python
Python Selenium库的基本使用教程
2021/01/04 Python
sklearn中的交叉验证的实现(Cross-Validation)
2021/02/22 Python
机关中层领导干部群众路线教育实践活动个人对照检查材料
2014/09/24 职场文书
工作失误检讨书
2015/01/26 职场文书
乌镇导游词
2015/02/02 职场文书
教师节校长致辞
2015/07/31 职场文书
CSS变量实现主题切换的方法
2021/06/23 HTML / CSS
MySql统计函数COUNT的具体使用详解
2022/08/14 MySQL