pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(一)
Jun 09 Python
Python访问MySQL封装的常用类实例
Nov 11 Python
简单解析Django框架中的表单验证
Jul 17 Python
Python3.6基于正则实现的计算器示例【无优化简单注释版】
Jun 14 Python
浅谈Python的list中的选取范围
Nov 12 Python
python实现逐个读取txt字符并修改
Dec 24 Python
python3使用matplotlib绘制散点图
Mar 19 Python
python并发编程 Process对象的其他属性方法join方法详解
Aug 20 Python
python元组的概念知识点
Nov 19 Python
基于python实现学生信息管理系统
Nov 22 Python
Python的Django框架实现数据库查询(不返回QuerySet的方法)
May 19 Python
Tensorflow全局设置可见GPU编号操作
Jun 30 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
在php中判断一个请求是ajax请求还是普通请求的方法
2011/06/28 PHP
3款值得推荐的微信开发开源框架
2014/10/28 PHP
PHP+JS实现大规模数据提交的方法
2015/07/02 PHP
3种方法轻松处理php开发中emoji表情的问题
2016/07/18 PHP
总结PHP中DateTime的常用方法
2016/08/11 PHP
PHP后台实现微信小程序登录
2018/08/03 PHP
基于thinkphp5框架实现微信小程序支付 退款 订单查询 退款查询操作
2020/08/17 PHP
JavaScript DOM学习第四章 getElementByTagNames
2010/02/19 Javascript
javascript 防止刷新,后退,关闭
2010/08/07 Javascript
基于jQuery+HttpHandler实现图片裁剪效果代码(适用于论坛, SNS)
2011/09/02 Javascript
节点的插入之append()和appendTo()的用法介绍
2014/01/13 Javascript
Jquery+asp.net后台数据传到前台js进行解析的方法
2014/05/11 Javascript
JS实现图片无间断滚动代码汇总
2014/07/30 Javascript
javascript使用shift+click实现选择和反选checkbox的方法
2015/05/04 Javascript
js实现仿爱微网两级导航菜单效果代码
2015/08/31 Javascript
JavaScript知识点总结之如何提高性能
2016/01/15 Javascript
Javascript 实现微信分享(QQ、朋友圈、分享给朋友)
2016/10/21 Javascript
原生JavaScript来实现对dom元素class的操作方法(推荐)
2017/08/16 Javascript
基于jquery trigger函数无法触发a标签的两种解决方法
2018/01/06 jQuery
jQuery实现带右侧索引功能的通讯录示例【附源码下载】
2018/04/17 jQuery
IE浏览器下JS脚本提交表单后,不能自动提示问题解决方法
2019/06/04 Javascript
js实现一个简易计算器
2020/03/30 Javascript
Ant Design的可编辑Tree的实现操作
2020/10/31 Javascript
Python中几种操作字符串的方法的介绍
2015/04/09 Python
python图像处理之镜像实现方法
2015/05/30 Python
Python中type的构造函数参数含义说明
2015/06/21 Python
深入学习python多线程与GIL
2019/08/26 Python
解决python -m pip install --upgrade pip 升级不成功问题
2020/03/05 Python
tensorflow转换ckpt为savermodel模型的实现
2020/05/25 Python
python 生成器需注意的小问题
2020/09/29 Python
携程旅行网:中国领先的在线旅行服务公司
2017/02/17 全球购物
德国高尔夫商店:Golfshop.de
2019/06/22 全球购物
进程的查看和调度分别使用什么命令
2015/03/25 面试题
周年庆促销方案
2014/03/15 职场文书
小组口号大全
2014/06/09 职场文书
Spring boot实现上传文件到本地服务器
2022/08/14 Java/Android