pytorch __init__、forward与__call__的用法小结


Posted in Python onFebruary 27, 2021

1.介绍

当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?

1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的

2)forward是表示一个前向传播,构建网络层的先后运算步骤

3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?

当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。

在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

pytorch __init__、forward与__call__的用法小结

2.代码

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def forward(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x
 
class Net(nn.Module):
 def __init__(self, in_channels, mid_channels, out_channels):
 super(Net, self).__init__()
 self.conv0 = torch.nn.Sequential(
 torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 torch.nn.LeakyReLU())
 self.conv1 = torch.nn.Sequential(
 torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
 
 def __call__(self, x):
 x = self.conv0(x)
 x = self.conv1(x)
 return x

补充:torch/nn目录结构以及__init__.py

torch/nn目录结构以及init.py

pytorch __init__、forward与__call__的用法小结

torch/nn目录结构

__init__.py:

from .modules import *
#nn.modules  导入modules目录下内容 定义容器modules
from .parameter import Parameter
#nn.Parameter 导入parameter.py  定义parameter
from .parallel import DataParallel
#导入parallel目录下data_parallel.py中的DataParallel类
from . import init
#nn.init   导入init.py   参数初始化
from . import utils
#nn.utils  导入utils目录下内容 官网api下nn.utils下api

对于backends, functional.py, _functions 需要在代码前重新Import

例如我们常用的

import torch.nn.functional as F 就是导入了functional.py

backends和_functions是functional.py实现各种函数时所用到的。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
python 不关闭控制台的实现方法
Oct 23 Python
Python yield 使用浅析
May 28 Python
Python时间模块datetime、time、calendar的使用方法
Jan 13 Python
python对象及面向对象技术详解
Jul 19 Python
Python3实现并发检验代理池地址的方法
Sep 18 Python
解决python爬虫中有中文的url问题
May 11 Python
python numpy实现文件存取的示例代码
May 26 Python
在Python中append以及extend返回None的例子
Jul 20 Python
python批量修改ssh密码的实现
Aug 08 Python
PyCharm 在Windows的有用快捷键详解
Apr 07 Python
python 给图像添加透明度(alpha通道)
Apr 09 Python
Python Request类源码实现方法及原理解析
Aug 17 Python
python 实现有道翻译功能
Feb 26 #Python
Python爬取酷狗MP3音频的步骤
Feb 26 #Python
python利用xpath爬取网上数据并存储到django模型中
Feb 26 #Python
用python 绘制茎叶图和复合饼图
Feb 26 #Python
python lambda的使用详解
Feb 26 #Python
python爬虫scrapy框架之增量式爬虫的示例代码
Feb 26 #Python
详解Python openpyxl库的基本应用
Feb 26 #Python
You might like
mysql总结之explain
2012/02/27 PHP
解析php 版获取重定向后的地址(代码)
2013/06/26 PHP
PHP_NETWORK_GETADDRESSES: GETADDRINFO FAILED问题解决办法
2014/05/04 PHP
php实现word转html的方法
2016/01/22 PHP
PHP编程实现的TCP服务端和客户端功能示例
2018/04/13 PHP
javascript预览上传图片发现的问题的解决方法
2010/11/25 Javascript
js播放wav文件(源码)
2013/04/22 Javascript
window.showModalDialog参数传递中含有特殊字符的处理方法
2013/06/06 Javascript
JS 操作Array数组的方法及属性实例解析
2014/01/08 Javascript
jquery操作HTML5 的data-*的用法实例分享
2014/08/17 Javascript
JavaScript点击按钮后弹出透明浮动层的方法
2015/05/11 Javascript
JS自定义函数对web前端上传的文件进行类型大小判断
2016/10/19 Javascript
jQuery图片切换动画特效
2016/11/02 Javascript
JavaScript之Map和Set_动力节点Java学院整理
2017/06/29 Javascript
DVA框架统一处理所有页面的loading状态
2017/08/25 Javascript
基于es6三点运算符的使用方法(实例讲解)
2017/10/12 Javascript
Angular实现点击按钮控制隐藏和显示功能示例
2017/12/29 Javascript
使用javascript做时间倒数读秒功能的实例
2019/01/23 Javascript
微信小程序实现的canvas合成图片功能示例
2019/05/03 Javascript
微信小程序上传帖子的实例代码(含有文字图片的微信验证)
2020/07/11 Javascript
零基础写python爬虫之抓取糗事百科代码分享
2014/11/06 Python
python执行使用shell命令方法分享
2017/11/08 Python
Python之文字转图片方法
2018/05/10 Python
对python 中re.sub,replace(),strip()的区别详解
2019/07/22 Python
python 3.6.7实现端口扫描器
2019/09/04 Python
在python中做正态性检验示例
2019/12/09 Python
Python并发concurrent.futures和asyncio实例
2020/05/04 Python
css3实现针线缝合效果(图解步骤)
2013/02/04 HTML / CSS
日语专业求职信
2014/07/04 职场文书
法院反腐倡廉心得体会
2014/09/09 职场文书
古诗文之爱国名句(77句)
2019/09/24 职场文书
古诗之感恩老师
2019/10/24 职场文书
go原生库的中bytes.Buffer用法
2021/04/25 Golang
2022年四月新番
2022/03/15 日漫
Go 内联优化让程序员爱不释手
2022/06/21 Golang
ORACLE中dbms_output.put_line输出问题的解决过程
2022/06/28 Oracle