pytorch神经网络之卷积层与全连接层参数的设置方法


Posted in Python onAugust 18, 2019

当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。

后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。

全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再??铝耍?衷诳?n.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?

请看下文详解。

class AlexNet(nn.Module):
  def __init__(self):
    super(AlexNet, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(96, 256, kernel_size=5, padding=2),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(???, 4096)
      ......
      ......
    )

首先,我们先把forward写一下:

def forward(self, x):
    x = self.conv(x)
    print x.size()

就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:

import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable

if __name__ == '__main__':
  net = AlexNet()

  data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
  print data_input.size()
  net(data_input)

结果如下:

(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)

显而易见,咱们这个全连接层的input_features为256。

以上这篇pytorch神经网络之卷积层与全连接层参数的设置方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门_学会创建并调用函数的方法
May 16 Python
Python实现的字典值比较功能示例
Jan 08 Python
python中for用来遍历range函数的方法
Jun 08 Python
解决python3捕获cx_oracle抛出的异常错误问题
Oct 18 Python
对python操作kafka写入json数据的简单demo分享
Dec 27 Python
python交互界面的退出方法
Feb 16 Python
python图像和办公文档处理总结
May 28 Python
Pandas之Dropna滤除缺失数据的实现方法
Jun 25 Python
Python制作简易版小工具之计算天数的实现思路
Feb 13 Python
python GUI库图形界面开发之PyQt5控件QTableWidget详细使用方法与属性
Feb 25 Python
pycharm中导入模块错误时提示Try to run this command from the system terminal
Mar 26 Python
Python+Dlib+Opencv实现人脸采集并表情判别功能的代码
Jul 01 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
You might like
PHP5.3与5.5废弃与过期函数整理汇总
2014/07/10 PHP
php+xml编程之xpath的应用实例
2015/01/24 PHP
php导出中文内容excel文件类实例
2015/07/06 PHP
php通过前序遍历树实现无需递归的无限极分类
2015/07/10 PHP
PHP和MYSQL实现分页导航思路详解
2017/04/11 PHP
PHP文件操作实例总结【文件上传、下载、分页】
2018/12/08 PHP
Laravel相关的一些故障解决
2020/08/19 PHP
jquery 选择器部分整理
2009/10/28 Javascript
Javascript异步编程模型Promise模式详细介绍
2014/05/08 Javascript
javascript实现iframe框架延时加载的方法
2014/10/30 Javascript
JS实现带圆弧背景渐变效果的导航菜单代码
2015/10/13 Javascript
简单谈谈JavaScript的同步与异步
2015/12/31 Javascript
jQuery实现导航栏头部菜单项点击后变换颜色的方法
2017/07/19 jQuery
node内置调试方法总结
2018/02/22 Javascript
jQuery实现表格隔行换色
2018/09/01 jQuery
js实现按钮开关单机下拉菜单效果
2018/11/22 Javascript
微信小程序使用wx.request请求服务器json数据并渲染到页面操作示例
2019/03/30 Javascript
js实现整体缩放页面适配移动端
2020/03/31 Javascript
深入解析Python中的__builtins__内建对象
2016/06/21 Python
Python通过调用有道翻译api实现翻译功能示例
2018/07/19 Python
python实现修改固定模式的字符串内容操作示例
2019/12/30 Python
ansible动态Inventory主机清单配置遇到的坑
2020/01/19 Python
python+selenium定时爬取丁香园的新型冠状病毒数据并制作出类似的地图(部署到云服务器)
2020/02/09 Python
python实现音乐播放器 python实现花框音乐盒子
2020/02/25 Python
Django admin 实现search_fields精确查询实例
2020/03/30 Python
Django更新models数据库结构步骤
2020/04/01 Python
Python opencv相机标定实现原理及步骤详解
2020/04/09 Python
HTML5添加鼠标悬浮音响效果不使用FLASH
2014/04/23 HTML / CSS
前台接待员岗位职责
2014/01/02 职场文书
表扬稿格式范文
2015/01/16 职场文书
党校毕业个人总结
2015/02/28 职场文书
你为什么是穷人?可能是这5个缺点造成
2019/07/11 职场文书
教你利用python实现企业微信发送消息
2021/05/23 Python
Python数据可视化之绘制柱状图和条形图
2021/05/25 Python
Python实现byte转integer
2021/06/03 Python
Windows下载并安装MySQL8.0.x 版本的完整教程
2022/04/10 MySQL