pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
wxpython 学习笔记 第一天
Mar 16 Python
Python中文件遍历的两种方法
Jun 16 Python
跟老齐学Python之通过Python连接数据库
Oct 28 Python
python简单分割文件的方法
Jul 30 Python
Win10下Python环境搭建与配置教程
Nov 18 Python
Python实现基本数据结构中栈的操作示例
Dec 04 Python
python十进制和二进制的转换方法(含浮点数)
Jul 07 Python
Python批处理删除和重命名文件夹的实例
Jul 11 Python
使用python根据端口号关闭进程的方法
Nov 06 Python
对python列表里的字典元素去重方法详解
Jan 21 Python
python对象与json相互转换的方法
May 07 Python
python tkinter基本属性详解
Sep 16 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
PHP 存取 MySQL 数据库的一个例子
2006/10/09 PHP
PHP生成静态页
2006/11/25 PHP
Windows下编译PHP5.4和xdebug全记录
2015/04/03 PHP
ThinkPHP模板Volist标签嵌套循环输出多维数组的方法
2016/03/23 PHP
PHP模板引擎Smarty内建函数详解
2016/04/11 PHP
js jquery做的图片连续滚动代码
2008/01/06 Javascript
用C/C++来实现 Node.js 的模块(二)
2014/09/24 Javascript
JavaScript分析、压缩工具JavaScript Analyser
2014/12/31 Javascript
学习JavaScript正则表达式
2015/11/13 Javascript
模板视图和AngularJS之间冲突的解决方法
2016/11/22 Javascript
Bootstrap栅格系统的使用和理解2
2016/12/14 Javascript
Express框架之connect-flash详解
2017/05/31 Javascript
利用原生js实现html5小游戏之打砖块(附源码)
2018/01/03 Javascript
Vue在页面右上角实现可悬浮/隐藏的系统菜单
2018/05/04 Javascript
简单了解node npm cnpm的具体使用方法
2019/02/27 Javascript
详解vue使用$http服务端收不到参数
2019/04/19 Javascript
js 将线性数据转为树形的示例代码
2019/05/28 Javascript
[02:06]DOTA2肉山黑名单魔法终结者 敌法师中文配音鉴赏
2013/06/17 DOTA
Python isinstance判断对象类型
2008/09/06 Python
Python中的getopt函数使用详解
2015/07/28 Python
Python使用gensim计算文档相似性
2016/04/10 Python
python 使用get_argument获取url query参数
2017/04/28 Python
Python基于回溯法子集树模板解决数字组合问题实例
2017/09/02 Python
Python用于学习重要算法的模块pygorithm实例浅析
2018/08/16 Python
python实现kmp算法的实例代码
2019/04/03 Python
Python绘制频率分布直方图的示例
2019/07/08 Python
详解Python对JSON中的特殊类型进行Encoder
2019/07/15 Python
浅谈django url请求与数据库连接池的共享问题
2019/08/29 Python
Python常用数据类型之间的转换总结
2019/09/06 Python
python基于celery实现异步任务周期任务定时任务
2019/12/30 Python
python文件操作seek()偏移量,读取指正到指定位置操作
2020/07/05 Python
英国体育器材进口商店:UK Sport Imports
2017/03/14 全球购物
2015年高三年级组工作总结
2015/07/21 职场文书
SQL Server表分区删除详情
2021/10/16 SQL Server
Spring Boot项目传参校验的最佳实践指南
2022/04/05 Java/Android