Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
linux系统使用python获取cpu信息脚本分享
Jan 15 Python
python判断、获取一张图片主色调的2个实例
Apr 10 Python
MAC中PyCharm设置python3解释器
Dec 15 Python
python scatter散点图用循环分类法加图例
Mar 19 Python
pytorch 共享参数的示例
Aug 17 Python
Python 中的 import 机制之实现远程导入模块
Oct 29 Python
如何在django中添加日志功能
Feb 06 Python
解决python脚本中error: unrecognized arguments: True错误
Apr 20 Python
python 实现读取csv数据,分类求和 再写进 csv
May 18 Python
Django框架请求生命周期实现原理
Nov 13 Python
python中绕过反爬虫的方法总结
Nov 25 Python
python可视化之颜色映射详解
Sep 15 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
php获取字符串中各个字符出现次数的方法
2015/02/23 PHP
win7系统配置php+Apache+mysql环境的方法
2015/08/21 PHP
PHP实现的随机红包算法示例
2017/08/14 PHP
JavaScript 封装Ajax传递的数据代码
2009/06/05 Javascript
再说AutoComplete自动补全之实现原理
2011/11/05 Javascript
jQuery Tools tab(幻灯片)
2012/07/14 Javascript
js拦截alert对话框另类应用
2013/01/16 Javascript
JavaScript学习小结(一)——JavaScript入门基础
2015/09/02 Javascript
BootStrap的弹出框(Popover)支持鼠标移到弹出层上弹窗层不隐藏的原因及解决办法
2016/04/03 Javascript
分享javascript实现的冒泡排序代码并优化
2016/06/05 Javascript
老生常谈原生JS执行环境与作用域
2016/11/22 Javascript
微信小程序调用PHP后台接口 解析纯html文本
2017/06/13 Javascript
vue中遇到的坑之变化检测问题(数组相关)
2017/10/13 Javascript
深入理解node.js http模块
2018/01/24 Javascript
Vue 组件参数校验与非props特性的方法
2019/02/12 Javascript
如何在Angular应用中创建包含组件方法示例
2019/03/23 Javascript
vue使用代理解决请求跨域问题详解
2019/07/24 Javascript
vue 判断两个时间插件结束时间必选大于开始时间的代码
2020/11/04 Javascript
在Python中使用pngquant压缩png图片的教程
2015/04/09 Python
Python函数式编程指南(三):迭代器详解
2015/06/24 Python
Python 实现交换矩阵的行示例
2019/06/26 Python
python中时间、日期、时间戳的转换的实现方法
2019/07/06 Python
学python需要去培训机构吗
2020/07/01 Python
html5指南-2.如何操作document metadata
2013/01/07 HTML / CSS
美体小铺瑞典官方网站:The Body Shop瑞典
2018/01/27 全球购物
俄罗斯马克西多姆家居用品网上商店:Максидом
2020/02/06 全球购物
酒店前台接待岗位职责
2013/12/03 职场文书
金融行业职业生涯规划范文
2014/01/17 职场文书
专科应届毕业生求职信
2014/06/04 职场文书
2014购房个人委托书范本
2014/10/12 职场文书
爱牙日宣传活动总结
2015/02/05 职场文书
个人党性锻炼总结
2015/03/05 职场文书
倡议书范文大全
2015/04/28 职场文书
python pygame入门教程
2021/06/01 Python
Oracle11g R2 安装教程完整版
2021/06/04 Oracle
Java 通过手写分布式雪花SnowFlake生成ID方法详解
2022/04/07 Java/Android