pytorch快速搭建神经网络_Sequential操作


Posted in Python onJune 17, 2020

之前用Class类来搭建神经网络

class Neuro_net(torch.nn.Module):
  """神经网络"""
  def __init__(self, n_feature, n_hidden_layer, n_output):
    super(Neuro_net, self).__init__()
    self.hidden_layer = torch.nn.Linear(n_feature, n_hidden_layer)
    self.output_layer = torch.nn.Linear(n_hidden_layer, n_output)

  def forward(self, input):
    hidden_out = torch.relu(self.hidden_layer(input))
    out = self.output_layer(hidden_out)
    return out
  
net = Neuro_net(2, 10, 2)
print(net)

class类图结构:

pytorch快速搭建神经网络_Sequential操作

使用torch.nn.Sequential() 快速搭建神经网络

net = torch.nn.Sequential(
  torch.nn.Linear(2, 10),
  torch.nn.ReLU(),
  torch.nn.Linear(10, 2)
)
print(net)

Sequential图结构

pytorch快速搭建神经网络_Sequential操作

总结:

我们可以发现,使用torch.nn.Sequential会自动加入激励函数, 但是 class类net 中, 激励函数实际上是在 forward() 功能中才被调用的

使用class类中的torch.nn.Module,我们可以根据自己的需求改变传播过程

如果你需要快速构建或者不需要过多的过程,直接使用torch.nn.Sequential吧

补充知识:【PyTorch神经网络】使用Moudle和Sequential搭建神经网络

Module:

init中定义每个神经层的神经元个数,和神经元层数;

forward是继承nn.Moudle中函数,来实现前向反馈(加上激励函数)

# -*- coding: utf-8 -*-
# @Time  : 2019/11/5 10:43
# @Author : Chen
# @File  : neural_network_impl.py
# @Software: PyCharm
 
import torch
import torch.nn.functional as F
 
#data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
 
 
#第一种搭建方法:Module
# 其中,init中定义每个神经层的神经元个数,和神经元层数;
# forward是继承nn.Moudle中函数,来实现前向反馈(加上激励函数)
class Net(torch.nn.Module):
  def __init__(self):
    #继承__init__函数
    super(Net, self).__init__()
    #定义每层的形式
    #隐藏层线性输出feature->hidden
    self.hidden = torch.nn.Linear(1, 10)
    #输出层线性输出hidden->output
    self.predict = torch.nn.Linear(10, 1)
 
  #实现所有层的连接关系。正向传播输入值,神经网络分析输出值
  def forward(self, x):
    #x首先在隐藏层经过激励函数的计算
    x = F.relu(self.hidden(x))
    #到输出层给出预测值
    x = self.predict(x)
    return x
 
net = Net()
print(net)
 
print('\n\n')
 
#快速搭建:Sequential
#模板:net2 = torch.nn.Sequential()
 
net2 = torch.nn.Sequential(
  torch.nn.Linear(1, 10),
  torch.nn.ReLU(),
  torch.nn.Linear(10, 1)
)
print(net2)

pytorch快速搭建神经网络_Sequential操作

以上这篇pytorch快速搭建神经网络_Sequential操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python登陆asp网站页面的实现代码
Jan 14 Python
Python程序中使用SQLAlchemy时出现乱码的解决方案
Apr 24 Python
Python解析命令行读取参数--argparse模块使用方法
Jan 23 Python
对django xadmin自定义菜单的实例详解
Jan 03 Python
使用Python3+PyQT5+Pyserial 实现简单的串口工具方法
Feb 13 Python
python操作日志的封装方法(两种方法)
May 23 Python
python tkinter canvas使用实例
Nov 04 Python
pd.DataFrame统计各列数值多少的实例
Dec 05 Python
Pandas时间序列重采样(resample)方法中closed、label的作用详解
Dec 10 Python
Python如何访问字符串中的值
Feb 09 Python
python画图时设置分辨率和画布大小的实现(plt.figure())
Jan 08 Python
Python卷积神经网络图片分类框架详解分析
Nov 07 Python
浅谈Keras的Sequential与PyTorch的Sequential的区别
Jun 17 #Python
Keras之fit_generator与train_on_batch用法
Jun 17 #Python
基于Keras的格式化输出Loss实现方式
Jun 17 #Python
Tensorflow之MNIST CNN实现并保存、加载模型
Jun 17 #Python
tensorflow使用CNN分析mnist手写体数字数据集
Jun 17 #Python
解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题
Jun 17 #Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
You might like
PHP Cookie的使用教程详解
2013/06/03 PHP
php判断类是否存在函数class_exists用法分析
2014/11/14 PHP
Javascript 二维数组
2009/11/26 Javascript
js 加载并解析XML字符串的代码
2009/12/13 Javascript
jBox 2.3基于jquery的最新多功能对话框插件 常见使用问题解答
2011/11/10 Javascript
JavaScript中函数声明优先于变量声明的实例分析
2012/03/01 Javascript
关于jQuery对象数据缓存Cache原理以及jQuery.data详解
2013/04/07 Javascript
解决extjs grid 不随窗口大小自适应的改变问题
2014/01/26 Javascript
JS的location.href跳出框架打开新页面的方法
2014/09/04 Javascript
Angularjs 自定义服务的三种方式(推荐)
2016/08/02 Javascript
BootStrap Table 分页后重新搜索问题的解决办法
2016/08/08 Javascript
jQuery ajax MD5实现用户注册即时验证功能
2016/10/11 Javascript
jQuery Validate设置onkeyup验证的实例代码
2016/12/09 Javascript
JavaScript的for循环中嵌套一个点击事件的问题解决
2017/03/03 Javascript
浅谈jQuery框架Ajax常用选项
2017/07/08 jQuery
Vue项目webpack打包部署到Tomcat刷新报404错误问题的解决方案
2018/05/15 Javascript
浅谈node中的cluster集群
2018/06/02 Javascript
JS实现的碰撞检测与周期移动完整示例
2019/09/02 Javascript
vue实现数字动态翻牌的效果(开箱即用)
2019/12/08 Javascript
如何解决vue在ios微信"复制链接"功能问题
2020/03/26 Javascript
在Vue中获取自定义属性方法:data-id的实例
2020/09/09 Javascript
python正则表达式去掉数字中的逗号(python正则匹配逗号)
2013/12/25 Python
简单介绍Python中的decode()方法的使用
2015/05/18 Python
关于Django外键赋值问题详解
2017/08/13 Python
python绘制圆柱体的方法
2018/07/02 Python
Python使用pickle模块实现序列化功能示例
2018/07/13 Python
Flask入门之上传文件到服务器的方法示例
2018/07/18 Python
Python终端输出彩色字符方法详解
2020/02/11 Python
python用Configobj模块读取配置文件
2020/09/26 Python
PHP中如何使用Cookie
2015/10/28 面试题
拓展策划方案
2014/06/03 职场文书
党员对照检查材料整改措施思想汇报
2014/09/26 职场文书
工程部部长岗位职责
2015/02/12 职场文书
大学生村官驻村工作心得体会
2016/01/23 职场文书
SQL Server 忘记密码以及重新添加新账号
2022/04/26 SQL Server
shell进度条追踪指令执行时间的场景分析
2022/06/16 Servers