pytorch神经网络之卷积层与全连接层参数的设置方法


Posted in Python onAugust 18, 2019

当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。

后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。

全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再??铝耍?衷诳?n.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?

请看下文详解。

class AlexNet(nn.Module):
  def __init__(self):
    super(AlexNet, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(96, 256, kernel_size=5, padding=2),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(???, 4096)
      ......
      ......
    )

首先,我们先把forward写一下:

def forward(self, x):
    x = self.conv(x)
    print x.size()

就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:

import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable

if __name__ == '__main__':
  net = AlexNet()

  data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
  print data_input.size()
  net(data_input)

结果如下:

(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)

显而易见,咱们这个全连接层的input_features为256。

以上这篇pytorch神经网络之卷积层与全连接层参数的设置方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作日期和时间的方法
Mar 11 Python
python中常用的各种数据库操作模块和连接实例
May 29 Python
Python的gevent框架的入门教程
Apr 29 Python
微信跳一跳python辅助软件思路及图像识别源码解析
Jan 04 Python
python实现智能语音天气预报
Dec 02 Python
Python使用uuid库生成唯一标识ID
Feb 12 Python
Python变量及数据类型用法原理汇总
Aug 06 Python
scrapy与selenium结合爬取数据(爬取动态网站)的示例代码
Sep 28 Python
python 提高开发效率的5个小技巧
Oct 19 Python
python在协程中增加任务实例操作
Feb 28 Python
pytorch 中nn.Dropout的使用说明
May 20 Python
python数字图像处理:图像的绘制
Jun 28 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
You might like
mysql数据库差异比较的PHP代码
2012/02/05 PHP
thinkPHP中create方法与令牌验证实例浅析
2015/12/08 PHP
ThinkPHP3.2框架使用addAll()批量插入数据的方法
2017/03/16 PHP
php微信公众号开发之快递查询
2018/10/20 PHP
js加解密 脚本解密
2008/02/22 Javascript
TextArea不支持maxlength的解决办法(jquery)
2011/09/13 Javascript
JQuery 常用方法和事件详细介绍
2013/04/18 Javascript
简约JS日历控件 实例代码
2013/07/12 Javascript
Javascript毫秒数用法实例
2015/02/05 Javascript
JS基于Mootools实现的个性菜单效果代码
2015/10/21 Javascript
js操作table元素实现表格行列新增、删除技巧总结
2015/11/18 Javascript
基于jQuery实现收缩展开功能
2016/03/18 Javascript
用JS实现图片轮播效果代码(一)
2016/06/26 Javascript
基于SpringMVC+Bootstrap+DataTables实现表格服务端分页、模糊查询
2016/10/30 Javascript
原生javascript移动端滑动banner效果
2017/03/10 Javascript
vue3.0 CLI - 2.3 - 组件 home.vue 中学习指令和绑定
2018/09/14 Javascript
微信小程序使用npm支持踩坑
2018/11/07 Javascript
如何自动化部署项目?折腾服务器之旅~
2019/04/16 Javascript
小程序实现背景音乐播放和暂停
2020/06/19 Javascript
python 根据正则表达式提取指定的内容实例详解
2016/12/04 Python
python数据结构之链表的实例讲解
2017/07/25 Python
python实现数独游戏 java简单实现数独游戏
2018/03/30 Python
Java与Python两大幸存者谁更胜一筹呢
2018/04/12 Python
python opencv实现切变换 不裁减图片
2018/07/26 Python
Python QTimer实现多线程及QSS应用过程解析
2020/07/11 Python
python缩进长度是否统一
2020/08/02 Python
CSS3 网页下拉菜单代码解释 中文翻译
2010/02/27 HTML / CSS
Clarks其乐鞋荷兰官网:Clarks荷兰
2019/07/05 全球购物
护理专业本科生自荐信
2013/10/01 职场文书
优质的学校老师推荐信
2013/10/28 职场文书
见习期自我鉴定
2013/11/07 职场文书
材料物理专业个人求职信
2013/12/15 职场文书
大学生入党推荐书范文
2014/05/17 职场文书
2014年廉洁自律承诺书
2014/05/26 职场文书
Redis Cluster 集群搭建你会吗
2021/08/04 Redis
MySQL8.0 Undo Tablespace管理详解
2022/06/16 MySQL