pytorch神经网络之卷积层与全连接层参数的设置方法


Posted in Python onAugust 18, 2019

当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。

后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。

全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再??铝耍?衷诳?n.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?

请看下文详解。

class AlexNet(nn.Module):
  def __init__(self):
    super(AlexNet, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(96, 256, kernel_size=5, padding=2),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(???, 4096)
      ......
      ......
    )

首先,我们先把forward写一下:

def forward(self, x):
    x = self.conv(x)
    print x.size()

就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:

import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable

if __name__ == '__main__':
  net = AlexNet()

  data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
  print data_input.size()
  net(data_input)

结果如下:

(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)

显而易见,咱们这个全连接层的input_features为256。

以上这篇pytorch神经网络之卷积层与全连接层参数的设置方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中有趣在__call__函数
Jun 21 Python
详解python string类型 bytes类型 bytearray类型
Dec 16 Python
Django教程笔记之中间件middleware详解
Aug 01 Python
python实现键盘控制鼠标移动
Nov 27 Python
Python判断以什么结尾以什么开头的实例
Oct 27 Python
Python实现批量修改图片格式和大小的方法【opencv库与PIL库】
Dec 03 Python
Python遍历文件夹 处理json文件的方法
Jan 22 Python
python Elasticsearch索引建立和数据的上传详解
Aug 04 Python
python 解压、复制、删除 文件的实例代码
Feb 26 Python
python 自动识别并连接串口的实现
Jan 19 Python
使用pytorch实现线性回归
Apr 11 Python
详解Python为什么不用设计模式
Jun 24 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
You might like
对squid中refresh_pattern的一些理解和建议
2009/04/17 PHP
php遍历类中包含的所有元素的方法
2015/05/12 PHP
php实现的操作excel类详解
2016/01/15 PHP
功能强大的php分页函数
2016/07/20 PHP
(JS实现)MapBar中坐标的加密和解密的脚本
2007/05/16 Javascript
js AspxButton的客户端操作
2009/06/26 Javascript
JavaScript和ActionScript的交互实现代码
2010/08/01 Javascript
JavaScript小技巧 2.5 则
2010/09/12 Javascript
jQuery 源码分析笔记(2) 变量列表
2011/05/28 Javascript
js使用心得分享
2015/01/13 Javascript
JS创建对象几种不同方法详解
2016/03/01 Javascript
JS实现动态给标签控件添加事件的方法示例
2017/05/13 Javascript
详解Vue整合axios的实例代码
2017/06/21 Javascript
详解AngularJS2 Http服务
2017/06/26 Javascript
vue中如何创建多个ueditor实例教程
2017/11/14 Javascript
关于react中组件通信的几种方式详解
2017/12/10 Javascript
Vue不能观察到数组length的变化
2018/06/08 Javascript
JS通过ajax + 多列布局 + 自动加载实现瀑布流效果
2019/05/30 Javascript
react实现antd线上主题动态切换功能
2019/08/12 Javascript
js利用递归与promise 按顺序请求数据的方法
2019/08/30 Javascript
vuex入门最详细整理
2020/03/04 Javascript
[14:20]刀塔大凶女神互压各路奇葩屌丝
2014/05/16 DOTA
在Python中使用CasperJS获取JS渲染生成的HTML内容的教程
2015/04/09 Python
简述Python中的面向对象编程的概念
2015/04/27 Python
在ironpython中利用装饰器执行SQL操作的例子
2015/05/02 Python
python 多线程实现检测服务器在线情况
2015/11/25 Python
分析python切片原理和方法
2017/12/19 Python
解决python删除文件的权限错误问题
2018/04/24 Python
Pycharm Git 设置方法
2020/09/15 Python
通用的Django注册功能模块实现方法
2021/02/05 Python
Tory Burch德国官网:美国时尚生活品牌
2018/01/03 全球购物
英国珠宝和手表专家:Pleasance & Harper
2020/10/21 全球购物
2014年出纳工作总结与计划
2014/12/09 职场文书
2015年暑期社会实践报告
2015/07/13 职场文书
幼儿园园长新年寄语
2015/08/17 职场文书
2021-4-3课程——SQL Server查询【2】
2021/04/05 SQL Server