pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3.2中的字符串函数学习总结
Apr 23 Python
使用Python压缩和解压缩zip文件的教程
May 06 Python
Collatz 序列、逗号代码、字符图网格实例
Jun 22 Python
python使用xslt提取网页数据的方法
Feb 23 Python
Python读取视频的两种方法(imageio和cv2)
Apr 15 Python
Matplotlib 生成不同大小的subplots实例
May 25 Python
详解python里的命名规范
Jul 16 Python
python tkinter基本属性详解
Sep 16 Python
Python生成器实现简单"生产者消费者"模型代码实例
Mar 27 Python
Python使用tkinter实现摇骰子小游戏功能的代码
Jul 02 Python
scrapy头部修改的方法详解
Dec 06 Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
php 传值赋值与引用赋值的区别
2010/12/29 PHP
php5.3 goto函数介绍和示例
2014/03/21 PHP
用 Composer构建自己的 PHP 框架之基础准备
2014/10/30 PHP
PHP+MySql+jQuery实现的"顶"和"踩"投票功能
2016/05/21 PHP
详解PHP中的 input属性(隐藏 只读 限制)
2017/08/14 PHP
Mootools 1.2教程 Fx.Tween的使用
2009/09/15 Javascript
js实现的真正的iframe高度自适应(兼容IE,FF,Opera)
2010/03/07 Javascript
JavaScript中document.forms[0]与getElementByName区别
2015/01/21 Javascript
高性能JavaScript模板引擎实现原理详解
2015/02/05 Javascript
js用拖动滑块来控制图片大小的方法
2015/02/27 Javascript
JavaScript获取页面中表单(form)数量的方法
2015/04/03 Javascript
基于jQuery实现淡入淡出效果轮播图
2020/07/31 Javascript
Angular之指令Directive用法详解
2017/03/01 Javascript
JavaScript简介_动力节点Java学院整理
2017/06/26 Javascript
使用travis-ci如何持续部署node.js应用详解
2017/07/30 Javascript
深入浅出webpack教程系列_安装与基本打包用法和命令参数详解
2017/09/10 Javascript
微信小程序列表中item左滑删除功能
2018/11/07 Javascript
JQuery事件委托(适用于给动态生成的脚本元素添加事件)
2020/02/01 jQuery
[01:10:30]DOTA2-DPC中国联赛正赛 Dragon vs Dynasty BO3 第一场 3月4日
2021/03/11 DOTA
Python结合ImageMagick实现多张图片合并为一个pdf文件的方法
2018/04/24 Python
Flask实现图片的上传、下载及展示示例代码
2018/08/03 Python
在Pycharm中自动添加时间日期作者等信息的方法
2019/01/16 Python
python 多线程串行和并行的实例
2019/02/22 Python
Python autoescape标签用法解析
2020/01/17 Python
英国复古服装购物网站:Collectif
2019/10/30 全球购物
科颜氏香港官方网店:Kiehl’s香港
2021/03/07 全球购物
Java文件和目录(IO)操作
2014/08/26 面试题
师范毕业生求职自荐信
2013/09/25 职场文书
生产部统计员岗位职责
2014/01/05 职场文书
办理生育手续介绍信
2014/01/14 职场文书
教师个人读书活动总结
2014/07/08 职场文书
优秀中职教师事迹材料
2014/08/26 职场文书
大学生实习证明范本
2014/09/19 职场文书
群众路线查摆问题整改措施
2014/10/10 职场文书
担保书怎么写 ?
2019/04/22 职场文书
SQL Server 数据库实验课第五周——常用查询条件
2021/04/05 SQL Server