pytorch神经网络之卷积层与全连接层参数的设置方法


Posted in Python onAugust 18, 2019

当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。

后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。

全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再??铝耍?衷诳?n.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?

请看下文详解。

class AlexNet(nn.Module):
  def __init__(self):
    super(AlexNet, self).__init__()

    self.conv = nn.Sequential(
      nn.Conv2d(3, 96, kernel_size=11, stride=4),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(96, 256, kernel_size=5, padding=2),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2),

      nn.Conv2d(256, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 384, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.Conv2d(384, 256, kernel_size=3, padding=1),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=3, stride=2)
    )

    self.fc = nn.Sequential(
      nn.Linear(???, 4096)
      ......
      ......
    )

首先,我们先把forward写一下:

def forward(self, x):
    x = self.conv(x)
    print x.size()

就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:

import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable

if __name__ == '__main__':
  net = AlexNet()

  data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
  print data_input.size()
  net(data_input)

结果如下:

(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)

显而易见,咱们这个全连接层的input_features为256。

以上这篇pytorch神经网络之卷积层与全连接层参数的设置方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python之多进程和进程池(Processing库)
Jun 09 Python
Python实现扩展内置类型的方法分析
Oct 16 Python
python中实现控制小数点位数的方法
Jan 24 Python
在Python 中同一个类两个函数间变量的调用方法
Jan 31 Python
python爬虫之验证码篇3-滑动验证码识别技术
Apr 11 Python
PyQt QCombobox设置行高的方法
Jun 20 Python
Python3多线程版TCP端口扫描器
Aug 31 Python
基于Python实现2种反转链表方法代码实例
Jul 06 Python
Cpython解释器中的GIL全局解释器锁
Nov 09 Python
pytorch学习教程之自定义数据集
Nov 10 Python
详解使用scrapy进行模拟登陆三种方式
Feb 21 Python
python链表类中获取元素实例方法
Feb 23 Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
pytorch 自定义数据集加载方法
Aug 18 #Python
PyTorch的Optimizer训练工具的实现
Aug 18 #Python
Pytorch反向求导更新网络参数的方法
Aug 17 #Python
pytorch 模型可视化的例子
Aug 17 #Python
pytorch 输出中间层特征的实例
Aug 17 #Python
You might like
星际争霸 Starcraft 秘技补丁
2020/03/14 星际争霸
利用递归把多维数组转为一维数组的函数
2006/10/09 PHP
php在线生成ico文件的代码
2007/10/09 PHP
用phpmyadmin更改mysql5.0登录密码
2008/03/25 PHP
php实现singleton()单例模式实例
2014/11/06 PHP
php中实现记住密码下次自动登录的例子
2014/11/06 PHP
PHP调用微博接口实现微博登录的方法示例
2018/09/22 PHP
PHP中soap用法示例【SoapServer服务端与SoapClient客户端编写】
2018/12/25 PHP
比较详细的关于javascript中void(0)的具体含义解释
2007/08/02 Javascript
不使用中间变量,交换int型的 a, b两个变量的值。
2010/10/29 Javascript
jquery实现图片左右间隔滚动特效(可自动播放)
2013/05/08 Javascript
Jquery easyui 下loaing效果示例代码
2013/08/12 Javascript
js 异步操作回调函数如何控制执行顺序
2013/12/24 Javascript
js判断浏览器类型为ie6时不执行
2014/06/15 Javascript
JavaScript中创建字典对象(dictionary)实例
2015/03/31 Javascript
jQuery实现div拖拽效果实例分析
2016/02/20 Javascript
node.js微信公众平台开发教程
2016/03/04 Javascript
javascript拖拽效果延伸学习
2016/04/04 Javascript
Jquery元素追加和删除的实现方法
2016/05/24 Javascript
JS中使用 after 伪类清除浮动实例
2017/03/01 Javascript
JavaScript学习总结之正则的元字符和一些简单的应用
2017/06/30 Javascript
使用Vue实现简单计算器
2020/02/25 Javascript
element-ui tree结构实现增删改自定义功能代码
2020/08/31 Javascript
Python 流程控制实例代码
2009/09/25 Python
Python中浅拷贝copy与深拷贝deepcopy的简单理解
2018/10/26 Python
Python3 SSH远程连接服务器的方法示例
2018/12/29 Python
Python使用Shelve保存对象方法总结
2019/01/28 Python
python判断单向链表是否包括环,若包含则计算环入口的节点实例分析
2019/10/23 Python
Python远程方法调用实现过程解析
2020/07/28 Python
利用python查看数组中的所有元素是否相同
2021/01/08 Python
Timberland法国官网:购买靴子、鞋子、衣服、夹克和配饰
2019/11/30 全球购物
市场营销专科应届生求职信
2013/11/24 职场文书
运动会广播稿200字
2014/01/15 职场文书
小型婚礼主持词
2015/06/30 职场文书
网吧员工管理制度
2015/08/05 职场文书
解析目标检测之IoU
2021/06/26 Python