Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序设计入门(3)数组的使用
Jun 16 Python
使用Node.js和Socket.IO扩展Django的实时处理功能
Apr 20 Python
python获得文件创建时间和修改时间的方法
Jun 30 Python
通过数据库向Django模型添加字段的示例
Jul 21 Python
浅谈Python中的可变对象和不可变对象
Jul 07 Python
Python 处理数据的实例详解
Aug 10 Python
Python实现的概率分布运算操作示例
Aug 14 Python
Python实现连接两个无规则列表后删除重复元素并升序排序的方法
Feb 05 Python
Pycharm2017版本设置启动时默认自动打开项目的方法
Oct 29 Python
Python PIL读取的图像发生自动旋转的实现方法
Jul 05 Python
Window10下python3.7 安装与卸载教程图解
Sep 30 Python
python使用html2text库实现从HTML转markdown的方法详解
Feb 21 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
电脑硬件及电脑配置知识大全
2020/03/17 数码科技
smarty模板引擎之内建函数用法
2015/03/30 PHP
docker-compose部署php项目实例详解
2019/07/30 PHP
php5.3/5.4/5.5/5.6/7常见新增特性汇总整理
2020/02/27 PHP
jquery ajax 检测用户注册时用户名是否存在
2009/11/03 Javascript
node.js中的path.extname方法使用说明
2014/12/09 Javascript
基于javascript、ajax、memcache和PHP实现的简易在线聊天室
2015/02/03 Javascript
JavaScript如何调试有哪些建议和技巧附五款有用的调试工具
2015/10/28 Javascript
jQuery实现form表单元素序列化为json对象的方法
2015/12/09 Javascript
canvas绘制七巧板
2017/02/03 Javascript
十大热门的JavaScript框架和库
2017/03/21 Javascript
jQuery Plupload上传插件的使用
2017/04/19 jQuery
打通前后端构建一个Vue+Express的开发环境
2018/07/17 Javascript
在element-ui的select下拉框加上滚动加载
2019/04/18 Javascript
jQuery实现条件搜索查询、实时取值及升降序排序的方法分析
2019/05/04 jQuery
微信小程序实现手势滑动效果
2019/08/26 Javascript
Python标准库urllib2的一些使用细节总结
2015/03/16 Python
举例讲解Python中的算数运算符的用法
2015/05/13 Python
Python实现的寻找前5个默尼森数算法示例
2018/03/25 Python
pandas 取出表中一列数据所有的值并转换为array类型的方法
2018/04/11 Python
使用Python正则表达式操作文本数据的方法
2019/05/14 Python
解决Python import docx出错DLL load failed的问题
2020/02/13 Python
python数据分析工具之 matplotlib详解
2020/04/09 Python
CSS3的Flexbox布局的简明入门指南
2016/04/08 HTML / CSS
CSS3的first-child选择器实战攻略
2016/04/28 HTML / CSS
html5使用canvas画空心圆与实心圆
2014/12/15 HTML / CSS
SheIn俄罗斯:时尚女装网上商店
2017/02/28 全球购物
Happy Socks英国官网:购买五颜六色的袜子
2020/11/03 全球购物
中学生学雷锋演讲稿
2014/04/26 职场文书
创文明城市标语
2014/06/16 职场文书
2014年党风建设工作总结
2014/11/19 职场文书
医院科室评语
2015/01/04 职场文书
2015年乡镇科普工作总结
2015/05/13 职场文书
TensorFlow的自动求导原理分析
2021/05/26 Python
总结python多进程multiprocessing的相关知识
2021/06/29 Python
MySQL提取JSON字段数据实现查询
2022/04/22 MySQL