Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python translator使用实例
Sep 06 Python
对python 通过ssh访问数据库的实例详解
Feb 19 Python
Python使用到第三方库PyMuPDF图片与pdf相互转换
May 03 Python
python 中的列表生成式、生成器表达式、模块导入
Jun 19 Python
pandas的to_datetime时间转换使用及学习心得
Aug 11 Python
Django rstful登陆认证并检查session是否过期代码实例
Aug 13 Python
tesserocr与pytesseract模块的使用方法解析
Aug 30 Python
tensorflow实现测试时读取任意指定的check point的网络参数
Jan 21 Python
Python如何执行系统命令
Sep 23 Python
安装pyinstaller遇到的各种问题(小结)
Nov 20 Python
python 如何上传包到pypi
Dec 24 Python
Python读写Excel表格的方法
Mar 02 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
输出控制类
2006/10/09 PHP
一个程序下载的管理程序(三)
2006/10/09 PHP
一段php加密解密的代码
2007/07/16 PHP
使用php来实现网络服务
2009/09/15 PHP
PHP中文件读、写、删的操作(PHP中对文件和目录操作)
2012/03/06 PHP
双冒号 ::在PHP中的使用情况
2015/11/05 PHP
PHP基于rabbitmq操作类的生产者和消费者功能示例
2018/06/16 PHP
jQuery Tools tab(幻灯片)
2012/07/14 Javascript
目前流行的JavaScript库的介绍及对比
2013/09/29 Javascript
JS截取url中问号后面参数的值信息
2014/04/29 Javascript
判断iframe里的页面是否加载完成
2014/06/06 Javascript
node.js使用npm 安装插件时提示install Error: ENOENT报错的解决方法
2014/11/20 Javascript
C#中使用迭代器处理等待任务
2015/07/13 Javascript
jQuery添加删除DOM元素方法详解
2016/01/18 Javascript
VUE + UEditor 单图片跨域上传功能的实现方法
2018/02/08 Javascript
nodejs+mongodb aggregate级联查询操作示例
2018/03/17 NodeJs
微信小程序学习笔记之登录API与获取用户信息操作图文详解
2019/03/29 Javascript
微信小程序实现卡片左右滑动效果的示例代码
2019/05/01 Javascript
Python 元类使用说明
2009/12/18 Python
python使用PyGame绘制图像并保存为图片文件的方法
2015/04/24 Python
使用Python的Tornado框架实现一个Web端图书展示页面
2016/07/11 Python
利用Python批量生成任意尺寸的图片
2016/08/29 Python
从运行效率与开发效率比较Python和C++
2018/12/14 Python
深入浅析python3中的unicode和bytes问题
2019/07/03 Python
通过Turtle库在Python中绘制一个鼠年福鼠
2020/02/03 Python
Jmeter HTTPS接口测试证书导入过程图解
2020/07/22 Python
舞会礼服和舞会鞋:PromGirl
2019/04/22 全球购物
P D PAOLA法国官网:西班牙著名的珠宝首饰品牌
2020/02/15 全球购物
巴西24小时在线药房:Drogasil
2020/06/20 全球购物
《放飞蜻蜓》教学反思
2014/04/27 职场文书
教师演讲稿开场白
2014/08/25 职场文书
创先争优演讲稿
2014/09/15 职场文书
介绍信模板
2015/01/31 职场文书
创业计划书之家政服务
2019/09/18 职场文书
Pandas||过滤缺失数据||pd.dropna()函数的用法说明
2021/05/14 Python
基于Redis的List实现特价商品列表功能
2021/08/30 Redis