pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 字符串中的字符倒转
Sep 06 Python
python中使用psutil查看内存占用的情况
Jun 11 Python
Python中的取模运算方法
Nov 10 Python
python 生成图形验证码的方法示例
Nov 11 Python
python实现动态创建类的方法分析
Jun 25 Python
python频繁写入文件时提速的方法
Jun 26 Python
Django文件存储 自己定制存储系统解析
Aug 02 Python
应用OpenCV和Python进行SIFT算法的实现详解
Aug 21 Python
Pandas-Cookbook 时间戳处理方式
Dec 07 Python
pytorch中tensor张量数据类型的转化方式
Dec 31 Python
使用python批量修改XML文件中图像的depth值
Jul 22 Python
Django使用echarts进行可视化展示的实践
Jun 10 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
thinkphp学习笔记之多表查询
2014/07/28 PHP
php定义参数数量可变的函数用法实例
2015/03/16 PHP
WordPress中给媒体文件添加分类和标签的PHP功能实现
2015/12/31 PHP
Javascript实例教程(19) 使用HoTMetal(7)
2006/12/23 Javascript
javascript对talbe进行动态添加、删除、验证实现代码
2012/03/29 Javascript
ie6下png图片背景不透明的解决办法使用js实现
2013/01/11 Javascript
php跨域调用json的例子
2013/11/13 Javascript
JavaScript中停止执行setInterval和setTimeout事件的方法
2015/05/14 Javascript
jQuery 插件实现随机自由弹跳气泡样式
2017/01/12 Javascript
巧用weui.topTips验证数据的实例
2017/04/17 Javascript
ES6新特性八:async函数用法实例详解
2017/04/21 Javascript
Vue resource中的GET与POST请求的实例代码
2017/07/21 Javascript
详解从Vue.js源码看异步更新DOM策略及nextTick
2017/10/11 Javascript
五步轻松实现zTree的使用
2017/11/01 Javascript
vue 解决addRoutes动态添加路由后刷新失效问题
2018/07/02 Javascript
react-router 路由切换动画的实现示例
2018/12/03 Javascript
vue登录页实现使用cookie记住7天密码功能的方法
2021/02/18 Vue.js
使用python装饰器验证配置文件示例
2014/02/24 Python
用Django实现一个可运行的区块链应用
2018/03/08 Python
python使用rpc框架gRPC的方法
2018/08/24 Python
django利用request id便于定位及给日志加上request_id
2018/08/26 Python
Python 使用多属性来进行排序
2019/09/01 Python
python实现跨excel sheet复制代码实例
2020/03/03 Python
Python的PIL库中getpixel方法的使用
2020/04/09 Python
python openCV实现摄像头获取人脸图片
2020/08/20 Python
Python+OpenCV图像处理——实现直线检测
2020/10/23 Python
python接口自动化框架实战
2020/12/23 Python
CSS3按钮鼠标悬浮实现光圈效果源码
2016/09/11 HTML / CSS
html5移动端自适应布局的实现
2020/04/15 HTML / CSS
Structs界面控制层技术
2013/10/11 面试题
自动化工程专业个人应聘自荐信
2013/09/26 职场文书
酒店营销策划方案
2014/02/07 职场文书
成品库仓管员岗位职责
2014/04/06 职场文书
2014年教师业务学习材料
2014/05/12 职场文书
python删除csv文件的行列
2021/04/06 Python
web前端之css水平居中代码解析
2021/05/20 HTML / CSS