pytorch神经网络之卷积层与全连接层参数的设置方法


Posted in Python onAugust 18, 2019

当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。

后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。

全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再??铝耍?衷诳?n.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?

请看下文详解。

class AlexNet(nn.Module):
  def __init__(self):
    super(AlexNet, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(96, 256, kernel_size=5, padding=2),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(???, 4096)
      ......
      ......
    )

首先,我们先把forward写一下:

def forward(self, x):
    x = self.conv(x)
    print x.size()

就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:

import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable

if __name__ == '__main__':
  net = AlexNet()

  data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
  print data_input.size()
  net(data_input)

结果如下:

(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)

显而易见,咱们这个全连接层的input_features为256。

以上这篇pytorch神经网络之卷积层与全连接层参数的设置方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 简单的多线程链接实现代码
Aug 28 Python
Python实现两个list对应元素相减操作示例
Jun 09 Python
Python找出微信上删除你好友的人脚本写法
Nov 01 Python
神经网络相关之基础概念的讲解
Dec 29 Python
详解Python的三种拷贝方式
Feb 11 Python
Python 实现Image和Ndarray互相转换
Feb 19 Python
Python的Django框架实现数据库查询(不返回QuerySet的方法)
May 19 Python
python db类用法说明
Jul 07 Python
python解决OpenCV在读取显示图片的时候闪退的问题
Feb 23 Python
Python一行代码实现自动发邮件功能
May 30 Python
详解Python描述符的工作原理
Jun 11 Python
python基础入门之字典和集合
Jun 13 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
You might like
PHP获取文件绝对路径的代码(上一级目录)
2011/05/29 PHP
php导出word文档与excel电子表格的简单示例代码
2014/03/08 PHP
php短网址和数字之间相互转换的方法
2015/03/13 PHP
PHP+jQuery+Ajax实现分页效果 jPaginate插件的应用
2015/10/09 PHP
Laravel最佳分割路由文件(routes.php)的方式
2016/08/04 PHP
PHP hebrev()函数用法讲解
2019/02/21 PHP
CSS中简写属性要注意TRouBLe的顺序问题(避免踩坑)
2021/03/09 HTML / CSS
Javascript调用XML制作连动下拉列表框
2006/06/25 Javascript
修改好的jquery滚动字幕效果实现代码
2011/06/22 Javascript
jQuery插件Tooltipster实现漂亮的工具提示
2015/04/12 Javascript
php常见的页面跳转方法汇总
2015/04/15 Javascript
jQuery的实例及必知重要的jQuery选择器详解
2016/05/20 Javascript
JS使用cookie设置样式的方法
2016/06/30 Javascript
在web中js实现类似excel的表格控件
2016/09/01 Javascript
9个让JavaScript调试更简单的Console命令
2016/11/14 Javascript
vue页面跳转后返回原页面初始位置方法
2018/02/11 Javascript
angular 实现下拉列表组件的示例代码
2019/03/09 Javascript
详解在HTTPS 项目中使用百度地图 API
2019/04/26 Javascript
js实现带搜索功能的下拉框
2020/01/11 Javascript
使用vue打包进行云服务器上传的问题
2020/03/02 Javascript
Python找出list中最常出现元素的方法
2016/06/14 Python
Python实现的矩阵类实例
2017/08/22 Python
Python Flask框架扩展操作示例
2019/05/03 Python
python面试题Python2.x和Python3.x的区别
2019/05/28 Python
Python音频操作工具PyAudio上手教程详解
2019/06/26 Python
python return逻辑判断表达式实现解析
2019/12/02 Python
Python接口自动化判断元素原理解析
2020/02/24 Python
Jmeter HTTPS接口测试证书导入过程图解
2020/07/22 Python
python代码实现图书管理系统
2020/11/30 Python
瑞典网上购买现代和复古家具:Reforma
2019/10/21 全球购物
市场营销职业生涯规划书范文
2014/01/12 职场文书
前处理组长岗位职责
2014/03/01 职场文书
货车司机岗位职责
2014/03/18 职场文书
学校清明节活动总结
2014/07/04 职场文书
2016年五四青年节校园广播稿
2015/12/17 职场文书
Python GUI编程之tkinter 关于 ttkbootstrap 的使用详解
2022/03/03 Python