对Pytorch中nn.ModuleList 和 nn.Sequential详解


Posted in Python onAugust 18, 2019

简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。

需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:

nn.ModuleList(
      [nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
       range(nstack)])

其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!

摘录自

nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:

class LinearNet(nn.Module):
 def __init__(self, input_size, num_layers, layers_size, output_size):
   super(LinearNet, self).__init__()
 
   self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
   self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
   self.linears.append(nn.Linear(layers_size, output_size)

nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:

class Flatten(nn.Module):
 def forward(self, x):
  N, C, H, W = x.size() # read in N, C, H, W
  return x.view(N, -1)
 
simple_cnn = nn.Sequential(
      nn.Conv2d(3, 32, kernel_size=7, stride=2),
      nn.ReLU(inplace=True),
      Flatten(), 
      nn.Linear(5408, 10),
     )

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.

以上这篇对Pytorch中nn.ModuleList 和 nn.Sequential详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
从Python的源码来解析Python下的freeblock
May 11 Python
python获取外网ip地址的方法总结
Jul 02 Python
详解Python的Django框架中Manager方法的使用
Jul 21 Python
分享Python字符串关键点
Dec 13 Python
Python反射用法实例简析
Dec 22 Python
Sanic框架请求与响应实例分析
Jul 16 Python
python用BeautifulSoup库简单爬虫实例分析
Jul 30 Python
python3.x 生成3维随机数组实例
Nov 28 Python
Python3 io文本及原始流I/O工具用法详解
Mar 23 Python
对python中各个response的使用说明
Mar 28 Python
Python如何定义有可选参数的元类
Jul 31 Python
Python 删除List元素的三种方法remove、pop、del
Nov 16 Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
You might like
PHP与javascript的两种交互方式
2006/10/09 PHP
PHP正则匹配操作简单示例【preg_match_all应用】
2017/07/10 PHP
CI框架(CodeIgniter)实现的导入、导出数据操作示例
2018/05/24 PHP
jQuery源码分析之Callbacks详解
2015/03/13 Javascript
jQuery实现ajax的叠加和停止(终止ajax请求)
2016/08/08 Javascript
JS搜狐面试题分析
2016/12/16 Javascript
细说webpack源码之compile流程-rules参数处理技巧(1)
2017/12/26 Javascript
vue短信验证性能优化如何写入localstorage中
2018/04/25 Javascript
微信小程序获取用户信息并保存登录状态详解
2019/05/10 Javascript
微信打开网址添加在浏览器中打开提示的办法
2019/05/20 Javascript
pm2启动ssr失败的解决方法
2019/06/29 Javascript
在Node.js中将SVG图像转换为PNG,JPEG,TIFF,WEBP和HEIF格式的方法
2019/08/22 Javascript
Vue使用v-viewer实现图片预览
2020/10/21 Javascript
[30:51]DOTA2上海特级锦标赛主赛事日 - 3 胜者组第二轮#1Liquid VS MVP.Phx第一局
2016/03/04 DOTA
[33:15]2018DOTA2亚洲邀请赛3月30日 小组赛B组 VP VS Mineski
2018/03/31 DOTA
python教程之用py2exe将PY文件转成EXE文件
2014/06/12 Python
python实现每次处理一个字符的三种方法
2014/10/09 Python
Python的Django框架中设置日期和字段可选的方法
2015/07/17 Python
在Django的上下文中设置变量的方法
2015/07/20 Python
Python教程之全局变量用法
2016/06/27 Python
Python中属性和描述符的正确使用
2016/08/23 Python
Python 登录网站详解及实例
2017/04/11 Python
Python制作刷网页流量工具
2017/04/23 Python
python中模块查找的原理与方法详解
2017/08/11 Python
Python3 queue队列模块详细介绍
2018/01/05 Python
Python获取指定字符前面的所有字符方法
2018/05/02 Python
利用Pyhton中的requests包进行网页访问测试的方法
2018/12/26 Python
Python 微信爬虫完整实例【单线程与多线程】
2019/07/06 Python
英国最大的在线蜡烛商店:Candles Direct
2019/03/26 全球购物
道德与公民自我评价
2015/03/09 职场文书
停发工资证明范本
2015/06/12 职场文书
餐馆开业致辞
2015/08/01 职场文书
幼儿园2016圣诞节活动总结
2016/03/31 职场文书
普希金诗歌赏析(6首)
2019/08/22 职场文书
MySQL触发器的使用
2021/05/24 MySQL
vue如何清除浏览器历史栈
2022/05/25 Vue.js