基于pytorch的lstm参数使用详解


Posted in Python onJanuary 14, 2020

lstm(*input, **kwargs)

将多层长短时记忆(LSTM)神经网络应用于输入序列。

参数:

input_size:输入'x'中预期特性的数量

hidden_size:隐藏状态'h'中的特性数量

num_layers:循环层的数量。例如,设置' ' num_layers=2 ' '意味着将两个LSTM堆叠在一起,形成一个'堆叠的LSTM ',第二个LSTM接收第一个LSTM的输出并计算最终结果。默认值:1

bias:如果' False',则该层不使用偏置权重' b_ih '和' b_hh '。默认值:'True'

batch_first:如果' 'True ' ',则输入和输出张量作为(batch, seq, feature)提供。默认值: 'False'

dropout:如果非零,则在除最后一层外的每个LSTM层的输出上引入一个“dropout”层,相当于:attr:'dropout'。默认值:0

bidirectional:如果‘True',则成为双向LSTM。默认值:'False'

输入:input,(h_0, c_0)

**input**of shape (seq_len, batch, input_size):包含输入序列特征的张量。输入也可以是一个压缩的可变长度序列。

see:func:'torch.nn.utils.rnn.pack_padded_sequence' 或:func:'torch.nn.utils.rnn.pack_sequence' 的细节。

**h_0** of shape (num_layers * num_directions, batch, hidden_size):张量包含批处理中每个元素的初始隐藏状态。

如果RNN是双向的,num_directions应该是2,否则应该是1。

**c_0** of shape (num_layers * num_directions, batch, hidden_size):张量包含批处理中每个元素的初始单元格状态。

如果没有提供' (h_0, c_0) ',则**h_0**和**c_0**都默认为零。

输出:output,(h_n, c_n)

**output**of shape (seq_len, batch, num_directions * hidden_size) :包含LSTM最后一层输出特征' (h_t) '张量,

对于每个t. If a:class: 'torch.nn.utils.rnn.PackedSequence' 已经给出,输出也将是一个打包序列。

对于未打包的情况,可以使用'output.view(seq_len, batch, num_directions, hidden_size)',正向和反向分别为方向' 0 '和' 1 '。

同样,在包装的情况下,方向可以分开。

**h_n** of shape (num_layers * num_directions, batch, hidden_size):包含' t = seq_len '隐藏状态的张量。

与*output*类似, the layers可以使用以下命令分隔

h_n.view(num_layers, num_directions, batch, hidden_size) 对于'c_n'相似

**c_n** (num_layers * num_directions, batch, hidden_size):张量包含' t = seq_len '的单元状态

所有的权重和偏差都初始化自: 基于pytorch的lstm参数使用详解 where: 基于pytorch的lstm参数使用详解

include:: cudnn_persistent_rnn.rst
import torch
import torch.nn as nn
 
# 双向rnn例子
# rnn = nn.RNN(10, 20, 2)
# input = torch.randn(5, 3, 10)
# h0 = torch.randn(2, 3, 20)
# output, hn = rnn(input, h0)
# print(output.shape,hn.shape)
# torch.Size([5, 3, 20]) torch.Size([2, 3, 20])
 
# 双向lstm例子
rnn = nn.LSTM(10, 20, 2)   #(input_size,hidden_size,num_layers)
input = torch.randn(5, 3, 10)  #(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)    #(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)    #(num_layers * num_directions, batch, hidden_size)
# output:(seq_len, batch, num_directions * hidden_size)
# hn,cn(num_layers * num_directions, batch, hidden_size)
output, (hn, cn) = rnn(input, (h0, c0)) 
 
print(output.shape,hn.shape,cn.shape)
>>>torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])

以上这篇基于pytorch的lstm参数使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的pprint折腾记
Jan 21 Python
python遍历数组的方法小结
Apr 30 Python
Python连接MySQL并使用fetchall()方法过滤特殊字符
Mar 13 Python
详解Python的Flask框架中生成SECRET_KEY密钥的方法
Jun 07 Python
利用Python实现图书超期提醒
Aug 02 Python
深入解答关于Python的11道基本面试题
Apr 01 Python
Python实现的用户登录系统功能示例
Feb 05 Python
python用插值法绘制平滑曲线
Feb 19 Python
Python利用WMI实现ping命令的例子
Aug 14 Python
Python socket处理client连接过程解析
Mar 18 Python
Python Django搭建网站流程图解
Jun 13 Python
python os模块在系统管理中的应用
Jun 22 Python
Python利用逻辑回归模型解决MNIST手写数字识别问题详解
Jan 14 #Python
np.random.seed() 的使用详解
Jan 14 #Python
下载与当前Chrome对应的chromedriver.exe(用于python+selenium)
Jan 14 #Python
Python selenium 自动化脚本打包成一个exe文件(推荐)
Jan 14 #Python
pytorch+lstm实现的pos示例
Jan 14 #Python
Python中sorted()排序与字母大小写的问题
Jan 14 #Python
Pytorch实现LSTM和GRU示例
Jan 14 #Python
You might like
百度地图经纬度转换到腾讯地图/Google 对应的经纬度
2015/08/28 PHP
laravel框架创建授权策略实例分析
2019/11/22 PHP
解读IE和firefox下JScript和HREF的执行顺序
2008/01/12 Javascript
Javascript中定义方法的另类写法(批量定义js对象的方法)
2011/02/25 Javascript
js设置function参数默认值(适合没有传参情况)
2014/02/24 Javascript
jQuery实现“扫码阅读”功能
2015/01/21 Javascript
在Node.js应用中使用Redis的方法简介
2015/06/24 Javascript
JQuery日历插件My97DatePicker日期范围限制
2016/01/20 Javascript
图文详解Heap Sort堆排序算法及JavaScript的代码实现
2016/05/04 Javascript
Bootstrap页面布局基础知识全面解析
2016/06/13 Javascript
jquery判断类型是不是number类型的实例代码
2016/10/07 Javascript
vue中页面跳转拦截器的实现方法
2017/08/23 Javascript
seajs下require书写约定实例分析
2018/05/16 Javascript
Django使用详解:ORM 的反向查找(related_name)
2018/05/30 Python
python3.6使用urllib完成下载的实例
2018/12/19 Python
python+selenium实现简历自动刷新的示例代码
2019/05/20 Python
Numpy 中的矩阵求逆实例
2019/08/26 Python
Python实现的爬取豆瓣电影信息功能案例
2019/09/15 Python
Pytest mark使用实例及原理解析
2020/02/22 Python
Django数据结果集序列化并展示实现过程
2020/04/22 Python
用python打开摄像头并把图像传回qq邮箱(Pyinstaller打包)
2020/05/17 Python
html5实现输入框fixed定位在屏幕最底部兼容性
2020/07/03 HTML / CSS
瑞典时尚耳机品牌:Urbanears
2017/07/26 全球购物
世界上最受欢迎的钓鱼诱饵:Rapala
2019/05/02 全球购物
马来西亚户外装备商店:PTT Outdoor
2019/07/13 全球购物
交通事故调解协议书
2014/04/16 职场文书
保洁公司服务承诺书
2014/05/28 职场文书
学习教师法的心得体会
2014/09/03 职场文书
八项规定个人对照检查材料思想汇报
2014/09/25 职场文书
小学教师节活动总结
2015/03/20 职场文书
2019年大学生职业生涯规划书
2019/03/25 职场文书
什么是求职信?求职信应包含哪些内容?
2019/08/14 职场文书
Linux安装apache服务器的配置过程
2021/11/27 Servers
NGINX 权限控制文件预览和下载的实现原理
2022/01/18 Servers
俄罗斯十大城市人口排名,第三首都仅排第六,第二是北方首都
2022/03/20 杂记
电脑关机速度很慢怎么办 提升电脑关机速度设置教程
2022/04/08 数码科技