pytorch __init__、forward与__call__的用法小结


Posted in Python onFebruary 27, 2021

1.介绍

当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?

1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的

2)forward是表示一个前向传播,构建网络层的先后运算步骤

3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?

当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。

在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

pytorch __init__、forward与__call__的用法小结

2.代码

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def forward(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def __call__(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x

补充:torch/nn目录结构以及__init__.py

torch/nn目录结构以及init.py

pytorch __init__、forward与__call__的用法小结

torch/nn目录结构

__init__.py:

from .modules import *
#nn.modules  导入modules目录下内容 定义容器modules
from .parameter import Parameter
#nn.Parameter 导入parameter.py  定义parameter
from .parallel import DataParallel
#导入parallel目录下data_parallel.py中的DataParallel类
from . import init
#nn.init   导入init.py   参数初始化
from . import utils
#nn.utils  导入utils目录下内容 官网api下nn.utils下api

对于backends, functional.py, _functions 需要在代码前重新Import

例如我们常用的

import torch.nn.functional as F 就是导入了functional.py

backends和_functions是functional.py实现各种函数时所用到的。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python基于scrapy采集数据时使用代理服务器的方法
Apr 16 Python
python用10行代码实现对黄色图片的检测功能
Aug 10 Python
详解C++编程中一元运算符的重载
Jan 19 Python
Python 统计字数的思路详解
May 08 Python
Python中的类与类型示例详解
Jul 10 Python
python实现简单聊天室功能 可以私聊
Jul 12 Python
Python完全识别验证码自动登录实例详解
Nov 24 Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 Python
Keras 切换后端方式(Theano和TensorFlow)
Jun 19 Python
python用Configobj模块读取配置文件
Sep 26 Python
python绘制汉诺塔
Mar 01 Python
用Python监控你的朋友都在浏览哪些网站?
May 27 Python
python 实现有道翻译功能
Feb 26 #Python
Python爬取酷狗MP3音频的步骤
Feb 26 #Python
python利用xpath爬取网上数据并存储到django模型中
Feb 26 #Python
用python 绘制茎叶图和复合饼图
Feb 26 #Python
python lambda的使用详解
Feb 26 #Python
python爬虫scrapy框架之增量式爬虫的示例代码
Feb 26 #Python
详解Python openpyxl库的基本应用
Feb 26 #Python
You might like
PHP4实际应用经验篇(2)
2006/10/09 PHP
PHP和Mysqlweb应用开发核心技术-第1部分 Php基础-2 php语言介绍
2011/07/03 PHP
laravel学习教程之关联模型
2016/07/30 PHP
用js自动判断浏览器分辨率的代码
2007/01/28 Javascript
向fckeditor编辑器插入指定代码的方法
2007/05/25 Javascript
jquery.mousewheel实现整屏翻屏效果
2015/08/30 Javascript
快速掌握Node.js事件驱动模型
2016/03/21 Javascript
js对象浅拷贝和深拷贝详解
2016/09/05 Javascript
jquery把int类型转换成字符串类型的方法
2016/10/07 Javascript
Bootstrap导航条学习使用(二)
2017/02/08 Javascript
H5图片压缩与上传实例
2017/04/21 Javascript
jquery操作ul的一些操作笔记整理(干货)
2017/08/31 jQuery
从零开始学习搭建React脚手架项目
2018/08/23 Javascript
vue-cli配置flexible过程详解
2019/07/04 Javascript
微信小程序事件 bindtap bindinput代码实例
2019/08/26 Javascript
[46:10]2014 DOTA2国际邀请赛中国区预选赛 CnB VS HGT
2014/05/21 DOTA
Python交换变量
2008/09/06 Python
Python Web服务器Tornado使用小结
2014/05/06 Python
Python线程详解
2015/06/24 Python
Django基础知识与基本应用入门教程
2018/07/20 Python
详解Numpy中的广播原则/机制
2018/09/20 Python
Python 从列表中取值和取索引的方法
2018/12/25 Python
使用django实现一个代码发布系统
2019/07/18 Python
pandas DataFrame的修改方法(值、列、索引)
2019/08/02 Python
Python+kivy BoxLayout布局示例代码详解
2020/12/28 Python
Python3使用tesserocr识别字母数字验证码的实现
2021/01/29 Python
MSC邮轮官方网站:加勒比海、地中海和世界各地的假期
2018/08/27 全球购物
时尚设计师手表:The Watch Cabin
2018/10/06 全球购物
消防安全管理制度
2014/02/01 职场文书
黄继光的英雄事迹材料
2014/02/13 职场文书
多媒体专业自我鉴定
2014/02/28 职场文书
公务员检讨书
2014/11/01 职场文书
JavaScript组合继承详解
2021/11/07 Javascript
详解Python+OpenCV绘制灰度直方图
2022/03/22 Python
vue实现列表垂直无缝滚动
2022/04/08 Vue.js
MySQL 原理与优化之Update 优化
2022/08/14 MySQL