Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python求解物理学中的双弹簧质能系统详解
Sep 29 Python
python编程羊车门问题代码示例
Oct 25 Python
在matplotlib的图中设置中文标签的方法
Dec 13 Python
对python生成业务报表的实例详解
Feb 03 Python
基于python生成器封装的协程类
Mar 20 Python
Python面向对象思想与应用入门教程【类与对象】
Apr 12 Python
计算机二级python学习教程(2) python语言基本语法元素
May 16 Python
利用Django模版生成树状结构实例代码
May 19 Python
NumPy统计函数的实现方法
Jan 21 Python
python 实现批量图片识别并翻译
Nov 02 Python
python 实现汉诺塔游戏
Nov 28 Python
python区块链实现简版工作量证明
May 25 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
全国FM电台频率大全 - 11 浙江省
2020/03/11 无线电
笑谈配置,使用Smarty技术
2007/01/04 PHP
PHP 日期时间函数的高级应用技巧
2009/10/10 PHP
PHP curl 抓取AJAX异步内容示例
2014/09/09 PHP
20个2014年最优秀的PHP框架回顾
2014/10/22 PHP
PHP中读取文件的几个方法总结(推荐)
2016/06/03 PHP
jquery+thinkphp实现跨域抓取数据的方法
2016/10/15 PHP
thinkphp Apache配置重启Apache1 restart 出错解决办法
2017/02/15 PHP
thinkPHP实现签到功能的方法
2017/03/15 PHP
三个思路解决laravel上传文件报错:413 Request Entity Too Large问题
2017/11/13 PHP
Yii 访问 Gii(脚手架)时出现 403 错误
2018/06/06 PHP
js获取html页面节点方法(递归方式)
2013/12/13 Javascript
推荐10 款 SVG 动画的 JavaScript 库
2015/03/24 Javascript
简介JavaScript中setUTCSeconds()方法的使用
2015/06/12 Javascript
JQuery自适应窗口大小导航菜单附源码下载
2015/09/01 Javascript
JS动态改变浏览器标题的方法
2016/04/06 Javascript
Android中Okhttp3实现上传多张图片同时传递参数
2017/02/18 Javascript
React Native时间转换格式工具类分享
2017/10/24 Javascript
浅谈MUI框架中加载外部网页或服务器数据的方法
2018/01/31 Javascript
vue实现扫码功能
2020/01/17 Javascript
jQuery实现颜色打字机的完整代码
2020/03/19 jQuery
python BeautifulSoup设置页面编码的方法
2015/04/03 Python
Python使用剪切板的方法
2017/06/06 Python
python 用正则表达式筛选文本信息的实例
2018/06/05 Python
python实现周期方波信号频谱图
2018/07/21 Python
详解PyCharm安装MicroPython插件的教程
2019/06/24 Python
Python turtle画图库&&画姓名实例
2020/01/19 Python
简单了解pytest测试框架setup和tearDown
2020/04/14 Python
Visual Studio Code搭建django项目的方法步骤
2020/09/17 Python
CSS3制作皮卡丘动画壁纸的示例
2020/11/02 HTML / CSS
HTML5使用Audio标签实现歌词同步的效果
2016/03/17 HTML / CSS
html5 canvas绘制网络字体的常用方法
2019/08/26 HTML / CSS
《再见了,亲人》教学反思
2014/02/26 职场文书
师范生求职信
2014/06/14 职场文书
幼儿园六一亲子活动方案
2014/08/26 职场文书
SQL Server数据库基本概念、组成、常用对象与约束
2022/03/20 SQL Server