Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python选择排序、冒泡排序、合并排序代码实例
Apr 10 Python
在Python中用has_key()方法查找键是否存在的教程
May 21 Python
python实现自动重启本程序的方法
Jul 09 Python
举例讲解Python编程中对线程锁的使用
Jul 12 Python
Python设计模式之门面模式简单示例
Jan 09 Python
python调用百度语音REST API
Aug 30 Python
Python实现堡垒机模式下远程命令执行操作示例
May 09 Python
python binascii 进制转换实例
Jun 12 Python
浅谈PySpark SQL 相关知识介绍
Jun 14 Python
树莓派3 搭建 django 服务器的实例
Aug 29 Python
centos7中安装python3.6.4的教程
Dec 11 Python
Python&Matlab实现灰狼优化算法的示例代码
Mar 21 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
php-fpm重启导致的程序执行中断问题详解
2019/04/29 PHP
PHP使用PhpSpreadsheet操作Excel实例详解
2020/03/26 PHP
PHP如何使用array_unshift()在数组开头插入元素
2020/09/01 PHP
HTML复选框和单选框 checkbox和radio事件介绍
2012/12/12 Javascript
jquery select动态加载选择(兼容各种浏览器)
2013/02/01 Javascript
JS中eval函数的使用示例
2013/07/21 Javascript
js实现身份证号码验证的简单实例
2014/02/19 Javascript
JQuery对表格进行操作的常用技巧总结
2014/04/23 Javascript
JavaScript获取页面中表单(form)数量的方法
2015/04/03 Javascript
JavaScript生成福利彩票双色球号码
2015/05/15 Javascript
jQuery中的一些小技巧
2017/01/18 Javascript
HTML5+Canvas调用手机拍照功能实现图片上传(下)
2017/04/21 Javascript
jquery.rotate.js实现可选抽奖次数和中奖内容的转盘抽奖代码
2017/08/23 jQuery
jQuery UI Draggable + Sortable 结合使用(实例讲解)
2017/09/07 jQuery
vue使用v-for实现hover点击效果
2018/09/29 Javascript
express+vue+mongodb+session 实现注册登录功能
2018/12/06 Javascript
如何解决webpack-dev-server代理常切换问题
2019/01/09 Javascript
[04:55]完美世界副总裁蔡玮:DOTA2的自由、公平与信任
2013/12/18 DOTA
[04:54]DOTA2-DPC中国联赛1月31日Recap集锦
2021/03/11 DOTA
python基于Tkinter库实现简单文本编辑器实例
2015/05/05 Python
python实现发送邮件及附件功能
2021/03/02 Python
python中验证码连通域分割的方法详解
2018/06/04 Python
Django之模型层多表操作的实现
2019/01/08 Python
pycharm 实现显示project 选项卡的方法
2019/01/17 Python
Django实现微信小程序的登录验证功能并维护登录态
2019/07/04 Python
python读取ini配置的类封装代码实例
2020/01/08 Python
canvas 实现 github404动态效果的示例代码
2017/11/15 HTML / CSS
美国饼干礼物和美食甜点购买网站:Cheryl’s
2020/05/28 全球购物
车祸赔偿收入证明
2014/01/09 职场文书
工程造价专业大学生职业生涯规划书
2014/01/18 职场文书
英文留学推荐信范文
2014/01/25 职场文书
环境工程专业毕业生求职信
2014/09/30 职场文书
酒店收银员岗位职责
2015/04/07 职场文书
导游词之镇江-金山寺
2019/10/14 职场文书
怎么用Python识别手势数字
2021/06/07 Python
ssh服务器拒绝了密码 请再试一次已解决(亲测有效)
2022/08/14 Servers