Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
仅利用30行Python代码来展示X算法
Apr 01 Python
Python抓取百度查询结果的方法
Jul 08 Python
Python操作MongoDB数据库的方法示例
Jan 04 Python
python将字典内容存入mysql实例代码
Jan 18 Python
Python hashlib模块用法实例分析
Jun 12 Python
python得到一个excel的全部sheet标签值方法
Dec 10 Python
python七夕浪漫表白源码
Apr 05 Python
python 判断字符串中是否含有汉字或非汉字的实例
Jul 15 Python
Python图像读写方法对比
Nov 16 Python
python 爬虫请求模块requests详解
Dec 04 Python
python 获取计算机的网卡信息
Feb 18 Python
Django实现drf搜索过滤和排序过滤
Jun 21 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
php 数组的一个悲剧?
2011/05/11 PHP
php 文章调用类代码
2011/08/11 PHP
PHP eval函数使用介绍
2013/12/08 PHP
PHP搭建大文件切割分块上传功能示例
2017/01/04 PHP
PHP代码加密的方法总结
2020/03/13 PHP
ext checkboxgroup 回填数据解决
2009/08/21 Javascript
jquery中实现简单的tabs插件功能的代码
2011/03/02 Javascript
简单的jquery拖拽排序效果实现代码
2011/09/20 Javascript
js 数组去重的四种实用方法
2014/09/09 Javascript
基于js与flash实现的网站flv视频播放插件代码
2014/10/14 Javascript
详解JavaScript基于面向对象之继承实例
2015/12/16 Javascript
如何使用AngularJs打造权限管理系统【简易型】
2016/05/09 Javascript
前端程序员必须知道的高性能Javascript知识
2016/08/24 Javascript
AngularJS中的Promise详细介绍及实例代码
2016/12/13 Javascript
解决Angular.Js与Django标签冲突的方案
2016/12/20 Javascript
jQuery实现的简单在线计算器功能
2017/05/11 jQuery
Ionic项目中Native Camera的使用方法
2017/06/07 Javascript
原生js中ajax访问的实例详解
2017/09/19 Javascript
简单的Vue SSR的示例代码
2018/01/12 Javascript
js实现时钟定时器
2020/03/26 Javascript
跟老齐学Python之??碌某?? target=
2014/09/12 Python
在Django中使用Sitemap的方法讲解
2015/07/22 Python
Python中特殊函数集锦
2015/07/27 Python
python如何通过protobuf实现rpc
2016/03/06 Python
python redis 删除key脚本的实例
2019/02/19 Python
python爬虫学习笔记之Beautifulsoup模块用法详解
2020/04/09 Python
Python字节单位转换(将字节转换为K M G T)
2021/03/02 Python
css3.0 图形构成实例练习二
2013/03/19 HTML / CSS
GE设备配件:GE Appliance Parts(家电零件、配件和滤水器)
2018/11/28 全球购物
使用索引(Index)有哪些需要考虑的因素
2016/10/19 面试题
公司营业员的工作总结自我评价
2013/10/05 职场文书
服务员岗位责任制
2014/02/11 职场文书
学院党的群众路线教育实践活动第一阶段情况汇报
2014/10/25 职场文书
2015年学校教务处工作总结
2015/05/11 职场文书
利用Python第三方库实现预测NBA比赛结果
2021/06/21 Python
MySQL系列之二 多实例配置
2021/07/02 MySQL