Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用异步Socket编程性能测试
Jun 25 Python
跟老齐学Python之用while来循环
Oct 02 Python
用Python的Flask框架结合MySQL写一个内存监控程序
Nov 07 Python
pycharm下查看python的变量类型和变量内容的方法
Jun 26 Python
python处理数据,存进hive表的方法
Jul 04 Python
django用户登录验证的完整示例代码
Jul 21 Python
Laravel框架表单验证格式化输出的方法
Sep 25 Python
使用Python测试Ping主机IP和某端口是否开放的实例
Dec 17 Python
pytorch随机采样操作SubsetRandomSampler()
Jul 07 Python
PyTorch安装与基本使用详解
Aug 31 Python
Python 串口通信的实现
Sep 29 Python
python实现自动化群控的步骤
Apr 11 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
解析yii数据库的增删查改
2013/06/20 PHP
详解WordPress开发中的get_post与get_posts函数使用
2016/01/04 PHP
php生成验证码,缩略图及水印图的类分享
2016/04/07 PHP
Laravel 中使用 Vue.js 实现基于 Ajax 的表单提交错误验证操作
2017/06/30 PHP
浅谈PHP发送HTTP请求的几种方式
2017/07/25 PHP
Laravel Intervention/image图片处理扩展包的安装、使用与可能遇到的坑详解
2017/11/14 PHP
PHP开发中解决并发问题的几种实现方法分析
2017/11/13 PHP
PHP使用HTML5 FileApi实现Ajax上传文件功能示例
2019/07/01 PHP
php常用日期时间函数实例小结
2019/07/04 PHP
Laravel框架实现即点即改功能的方法分析
2019/10/31 PHP
IE DOM实现存在的部分问题及解决方法
2009/07/25 Javascript
JQuery Tips相关(1)----关于$.Ready()
2014/08/14 Javascript
JS中的phototype详解
2017/02/04 Javascript
微信小程序 图片加载(本地,网路)实例详解
2017/03/10 Javascript
vue通过cookie获取用户登录信息的思路详解
2018/10/30 Javascript
JS函数动态传递参数的方法分析【基于arguments对象】
2019/06/05 Javascript
Vue实现购物车详情页面的方法
2019/08/20 Javascript
Vue+Element-UI实现上传图片并压缩
2019/11/26 Javascript
JS获取表格视图所选行号的ids过程解析
2020/02/21 Javascript
vue使用nprogress加载路由进度条的方法
2020/06/04 Javascript
[01:51]2014DOTA2国际邀请赛 这个赛场没有失败者VGTi5再见
2014/07/23 DOTA
Python实现的检测网站挂马程序
2014/11/30 Python
Python的装饰器用法学习笔记
2016/06/24 Python
Python中的Numpy矩阵操作
2018/08/12 Python
在pycharm下设置自己的个性模版方法
2019/07/15 Python
简单了解python调用其他脚本方法实例
2020/03/26 Python
Python使用windows设置定时执行脚本
2020/11/12 Python
高清安全摄像头系统:Lorex Technology
2018/07/20 全球购物
周仰杰(JIMMY CHOO)法国官方网站:闻名世界的鞋子品牌
2019/09/27 全球购物
什么是makefile? 如何编写makefile?
2013/01/02 面试题
保安公司服务承诺书
2014/05/28 职场文书
生活小常识广播稿
2014/09/16 职场文书
六一儿童节主持开场白
2015/05/28 职场文书
2016年中秋节晚会领导致辞
2015/11/26 职场文书
幼师自荐信范文(2016推荐篇)
2016/01/28 职场文书
关于html选择框创建占位符的问题
2021/06/09 HTML / CSS