pytorch构建网络模型的4种方法


Posted in Python onApril 13, 2018

利用pytorch来构建网络模型有很多种方法,以下简单列出其中的四种。

假设构建一个网络模型如下:

卷积层--》Relu层--》池化层--》全连接层--》Relu层--》全连接层

首先导入几种方法用到的包:

import torch
import torch.nn.functional as F
from collections import OrderedDict

第一种方法

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2(x)
    return x

print("Method 1:")
model1 = Net1()
print(model1)

这种方法比较常用,早期的教程通常就是使用这种方法。

pytorch构建网络模型的4种方法

第二种方法

# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)

pytorch构建网络模型的4种方法

这种方法利用torch.nn.Sequential()容器进行快速搭建,模型的各层被顺序添加到容器中。缺点是每层的编号是默认的阿拉伯数字,不易区分。

第三种方法:

# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)

pytorch构建网络模型的4种方法

这种方法是对第二种方法的改进:通过add_module()添加每一层,并且为每一层增加了一个单独的名字。 

第四种方法:

# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

pytorch构建网络模型的4种方法

是第三种方法的另外一种写法,通过字典的形式添加每一层,并且设置单独的层名称。

完整代码:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Method 1 -----------------------------------------

class Net1(torch.nn.Module):
  def __init__(self):
    super(Net1, self).__init__()
    self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
    self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
    self.dense2 = torch.nn.Linear(128, 10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv(x)), 2)
    x = x.view(x.size(0), -1)
    x = F.relu(self.dense1(x))
    x = self.dense2()
    return x

print("Method 1:")
model1 = Net1()
print(model1)


# Method 2 ------------------------------------------
class Net2(torch.nn.Module):
  def __init__(self):
    super(Net2, self).__init__()
    self.conv = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(32 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 2:")
model2 = Net2()
print(model2)


# Method 3 -------------------------------
class Net3(torch.nn.Module):
  def __init__(self):
    super(Net3, self).__init__()
    self.conv=torch.nn.Sequential()
    self.conv.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
    self.conv.add_module("relu1",torch.nn.ReLU())
    self.conv.add_module("pool1",torch.nn.MaxPool2d(2))
    self.dense = torch.nn.Sequential()
    self.dense.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
    self.dense.add_module("relu2",torch.nn.ReLU())
    self.dense.add_module("dense2",torch.nn.Linear(128, 10))

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 3:")
model3 = Net3()
print(model3)



# Method 4 ------------------------------------------
class Net4(torch.nn.Module):
  def __init__(self):
    super(Net4, self).__init__()
    self.conv = torch.nn.Sequential(
      OrderedDict(
        [
          ("conv1", torch.nn.Conv2d(3, 32, 3, 1, 1)),
          ("relu1", torch.nn.ReLU()),
          ("pool", torch.nn.MaxPool2d(2))
        ]
      ))

    self.dense = torch.nn.Sequential(
      OrderedDict([
        ("dense1", torch.nn.Linear(32 * 3 * 3, 128)),
        ("relu2", torch.nn.ReLU()),
        ("dense2", torch.nn.Linear(128, 10))
      ])
    )

  def forward(self, x):
    conv_out = self.conv1(x)
    res = conv_out.view(conv_out.size(0), -1)
    out = self.dense(res)
    return out

print("Method 4:")
model4 = Net4()
print(model4)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中偏函数partial用法实例分析
Jul 08 Python
在Django框架中编写Contact表单的教程
Jul 17 Python
Python实现Sqlite将字段当做索引进行查询的方法
Jul 21 Python
python实现给微信公众号发送消息的方法
Jun 30 Python
linux环境下的python安装过程图解(含setuptools)
Nov 22 Python
python绘制条形图方法代码详解
Dec 19 Python
python range()函数取反序遍历sequence的方法
Jun 25 Python
python+selenium实现简历自动刷新的示例代码
May 20 Python
pytorch 实现查看网络中的参数
Jan 06 Python
tensorflow 实现数据类型转换
Feb 17 Python
Python库安装速度过慢解决方案
Jul 14 Python
python使用pywinauto驱动微信客户端实现公众号爬虫
May 19 Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
对numpy的array和python中自带的list之间相互转化详解
Apr 13 #Python
You might like
在PHP中使用XML
2006/10/09 PHP
php实现在多维数组中查找特定value的方法
2015/07/29 PHP
PHP获取当前URL路径的处理方法(适用于多条件筛选列表)
2017/02/10 PHP
PHP程序员必须知道的两种日志实例分析
2020/05/14 PHP
javascript面向对象之Javascript 继承
2010/05/04 Javascript
div+css布局的图片连续滚动js实现代码
2010/05/04 Javascript
javascript 使td内容不换行不撑开
2012/11/29 Javascript
兼容IE和Firefox火狐的上下、左右循环无间断滚动JS代码
2013/04/19 Javascript
javascript匿名函数应用示例介绍
2014/03/07 Javascript
jQuery将多条数据插入模态框的示例代码
2014/09/25 Javascript
利用AJAX实现WordPress中的文章列表及评论的分页功能
2016/05/17 Javascript
js与jquery正则验证电子邮箱、手机号、邮政编码的方法
2016/07/04 Javascript
前端面试题及答案整理(二)
2016/08/26 Javascript
Jquery树插件zTree实现菜单树
2017/01/24 Javascript
jQuery实现简单漂亮的Nav导航菜单效果
2017/03/29 jQuery
Vue CLI 3.x 自动部署项目至服务器的方法
2019/04/02 Javascript
微信小程序实现拖拽功能
2019/09/26 Javascript
JS Web Flex弹性盒子模型代码实例
2020/03/10 Javascript
VUE+elementui组件在table-cell单元格中绘制微型echarts图
2020/04/20 Javascript
python实现sublime3的less编译插件示例
2014/04/27 Python
Python实现将xml导入至excel
2015/11/20 Python
Python2.x利用commands模块执行Linux shell命令
2016/03/11 Python
python 容器总结整理
2017/04/04 Python
Python实现运行其他程序的四种方式实例分析
2017/08/17 Python
Pycharm debug调试时带参数过程解析
2020/02/03 Python
如何在mac下配置python虚拟环境
2020/07/06 Python
Django多数据库联用实现方法解析
2020/11/12 Python
python 多线程爬取壁纸网站的示例
2021/02/20 Python
html5/css3响应式页面开发总结
2018/10/16 HTML / CSS
台湾屈臣氏网路商店:Watsons台湾
2020/12/29 全球购物
《问银河》教学反思
2014/02/19 职场文书
工作决心书
2014/03/11 职场文书
2015年个人自我剖析材料
2014/12/29 职场文书
公司庆典欢迎词
2015/01/26 职场文书
郭明义观后感
2015/06/08 职场文书
标准版个人借条怎么写?以及什么是借条?
2019/08/28 职场文书