pytorch 中的重要模块化接口nn.Module的使用


Posted in Python onApril 02, 2020

torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络
nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法

查看源码

初始化部分:

def __init__(self):
  self._backend = thnn_backend
  self._parameters = OrderedDict()
  self._buffers = OrderedDict()
  self._backward_hooks = OrderedDict()
  self._forward_hooks = OrderedDict()
  self._forward_pre_hooks = OrderedDict()
  self._state_dict_hooks = OrderedDict()
  self._load_state_dict_pre_hooks = OrderedDict()
  self._modules = OrderedDict()
  self.training = True

属性解释:

  • _parameters:字典,保存用户直接设置的 Parameter
  • _modules:子 module,即子类构造函数中的内容
  • _buffers:缓存
  • _backward_hooks与_forward_hooks:钩子技术,用来提取中间变量
  • training:判断值来决定前向传播策略

方法定义:

def forward(self, *input):
 raise NotImplementedError

没有实际内容,用于被子类的 forward() 方法覆盖

且 forward 方法在 __call__ 方法中被调用:

def __call__(self, *input, **kwargs):
 for hook in self._forward_pre_hooks.values():
    hook(self, input)
  if torch._C._get_tracing_state():
    result = self._slow_forward(*input, **kwargs)
  else:
    result = self.forward(*input, **kwargs)
  ...
  ...

对于自己定义的网络,需要注意以下几点:

1)需要继承nn.Module类,并实现forward方法,只要在nn.Module的子类中定义forward方法,backward函数就会被自动实现(利用autograd机制)
2)一般把网络中可学习参数的层放在构造函数中__init__(),没有可学习参数的层如Relu层可以放在构造函数中,也可以不放在构造函数中(在forward函数中使用nn.Functional)
3)在forward中可以使用任何Variable支持的函数,在整个pytorch构建的图中,是Variable在流动,也可以使用for,print,log等
4)基于nn.Module构建的模型中,只支持mini-batch的Variable的输入方式,如,N*C*H*W

代码示例:

class LeNet(nn.Module):
  def __init__(self):
    # nn.Module的子类函数必须在构造函数中执行父类的构造函数
    super(LeNet, self).__init__() # 等价与nn.Module.__init__()

    # nn.Conv2d返回的是一个Conv2d class的一个对象,该类中包含forward函数的实现
    # 当调用self.conv1(input)的时候,就会调用该类的forward函数
    self.conv1 = nn.Conv2d(1, 6, (5, 5)) # output (N, C_{out}, H_{out}, W_{out})`
    self.conv2 = nn.Conv2d(6, 16, (5, 5))
    self.fc1 = nn.Linear(256, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    # F.max_pool2d的返回值是一个Variable, input:(10,1,28,28) ouput:(10, 6, 12, 12)
    x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
    # input:(10, 6, 12, 12)  output:(10,6,4,4)
    x = F.max_pool2d(F.relu(self.conv2(x)), (2, 2))
    # 固定样本个数,将其他维度的数据平铺,无论你是几通道,最终都会变成参数, output:(10, 256)
    x = x.view(x.size()[0], -1)
    # 全连接
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.relu(self.fc3(x))

    # 返回值也是一个Variable对象
    return x


def output_name_and_params(net):
  for name, parameters in net.named_parameters():
    print('name: {}, param: {}'.format(name, parameters))


if __name__ == '__main__':
  net = LeNet()
  print('net: {}'.format(net))
  params = net.parameters() # generator object
  print('params: {}'.format(params))
  output_name_and_params(net)

  input_image = torch.FloatTensor(10, 1, 28, 28)

  # 和tensorflow不一样,pytorch中模型的输入是一个Variable,而且是Variable在图中流动,不是Tensor。
  # 这可以从forward中每一步的执行结果可以看出
  input_image = Variable(input_image)

  output = net(input_image)
  print('output: {}'.format(output))
  print('output.size: {}'.format(output.size()))

到此这篇关于pytorch 中的重要模块化接口nn.Module的使用的文章就介绍到这了,更多相关pytorch nn.Module内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python 自动补全(vim)
Nov 30 Python
浅析Python中的join()方法的使用
May 19 Python
python3批量删除豆瓣分组下的好友的实现代码
Jun 07 Python
分享一个简单的python读写文件脚本
Nov 25 Python
python3安装pip3(install pip3 for python 3.x)
Apr 03 Python
python验证码识别教程之利用滴水算法分割图片
Jun 05 Python
Python实现迭代时使用索引的方法示例
Jun 05 Python
在python 不同时区之间的差值与转换方法
Jan 14 Python
如何通过python的fabric包完成代码上传部署
Jul 29 Python
手机使用python操作图片文件(pydroid3)过程详解
Sep 25 Python
python读取ini配置文件过程示范
Dec 23 Python
keras输出预测值和真实值方式
Jun 27 Python
python递归函数求n的阶乘,优缺点及递归次数设置方式
Apr 02 #Python
PyTorch中的C++扩展实现
Apr 02 #Python
python实现将列表中各个值快速赋值给多个变量
Apr 02 #Python
Python运行提示缺少模块问题解决方案
Apr 02 #Python
Pycharm配置PyQt5环境的教程
Apr 02 #Python
Python无头爬虫下载文件的实现
Apr 02 #Python
linux 下selenium chrome使用详解
Apr 02 #Python
You might like
2019年中国咖啡业现状与发展趋势
2021/03/04 咖啡文化
初次接触php抽象工厂模式(Elgg)
2010/03/21 PHP
PHP COOKIE及时生效的方法介绍
2014/02/14 PHP
php实现检查文章是否被百度收录
2015/01/27 PHP
phpinfo() 中 Local Value(局部变量)Master Value(主变量) 的区别
2016/02/03 PHP
php结合ajax实现手机发红包的案例
2016/10/13 PHP
详解如何实现Laravel的服务容器的方法示例
2019/04/15 PHP
用jquery设置按钮的disabled属性的实现代码
2010/11/28 Javascript
jquery中eq和get的区别与使用方法
2011/04/14 Javascript
JavaScript实现动态删除列表框值的方法
2015/08/12 Javascript
javascript日期比较方法实例分析
2016/06/17 Javascript
IOS中safari下的select下拉菜单文字过长不换行的解决方法
2016/09/26 Javascript
javascript ASCII和Hex互转的实现方法
2016/12/27 Javascript
javascript表单正则应用
2017/02/04 Javascript
jQuery实现左右滑动的toggle方法
2018/03/03 jQuery
浅析前端路由简介以及vue-router实现原理
2018/06/01 Javascript
nodejs之koa2请求示例(GET,POST)
2018/08/07 NodeJs
Vue.js组件高级特性实例详解
2018/12/24 Javascript
详解ES7 Decorator 入门解析
2019/02/18 Javascript
微信小程序中转义字符的处理方法
2019/03/28 Javascript
Python使用plotly绘制数据图表的方法
2017/07/18 Python
VTK与Python实现机械臂三维模型可视化详解
2017/12/13 Python
python实时获取外部程序输出结果的方法
2019/01/12 Python
Appium+python自动化怎么查看程序所占端口号和IP
2019/06/14 Python
Python+PyQT5的子线程更新UI界面的实例
2019/06/14 Python
python实现移动木板小游戏
2020/10/09 Python
美国首屈一指的礼品篮供应商:GiftTree
2018/01/06 全球购物
NBA欧洲商店(西班牙):NBA Europe Store ES
2019/04/16 全球购物
幼师专业毕业生自荐信
2013/09/29 职场文书
维修工先进事迹
2014/05/29 职场文书
经典团队口号
2014/06/06 职场文书
学生个人评语大全
2015/01/04 职场文书
房贷工资证明范本
2015/06/12 职场文书
庆七一主持词
2015/06/29 职场文书
MongoDB数据库的安装步骤
2021/06/18 MongoDB
Nginx实现负载均衡的项目实践
2022/03/18 Servers