Pytorch实现LSTM和GRU示例


Posted in Python onJanuary 14, 2020

为了解决传统RNN无法长时依赖问题,RNN的两个变体LSTM和GRU被引入。

LSTM

Long Short Term Memory,称为长短期记忆网络,意思就是长的短时记忆,其解决的仍然是短时记忆问题,这种短时记忆比较长,能一定程度上解决长时依赖。

Pytorch实现LSTM和GRU示例

上图为LSTM的抽象结构,LSTM由3个门来控制,分别是输入门、遗忘门和输出门。输入门控制网络的输入,遗忘门控制着记忆单元,输出门控制着网络的输出。最为重要的就是遗忘门,可以决定哪些记忆被保留,由于遗忘门的作用,使得LSTM具有长时记忆的功能。对于给定的任务,遗忘门能够自主学习保留多少之前的记忆,网络能够自主学习。

具体看LSTM单元的内部结构:

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

Pytorch实现LSTM和GRU示例

在每篇文章中,作者都会使用和标准LSTM稍微不同的版本,针对特定的任务,特定的网络结构往往表现更好。

GRU

Pytorch实现LSTM和GRU示例

上述的过程的线性变换没有使用偏置。隐藏状态参数不再是标准RNN的4倍,而是3倍,也就是GRU的参数要比LSTM的参数量要少,但是性能差不多。

Pytorch

在Pytorch中使用nn.LSTM()可调用,参数和RNN的参数相同。具体介绍LSTM的输入和输出:

输入: input, (h_0, c_0)

input:输入数据with维度(seq_len,batch,input_size)

h_0:维度为(num_layers*num_directions,batch,hidden_size),在batch中的

初始的隐藏状态.

c_0:初始的单元状态,维度与h_0相同

输出:output, (h_n, c_n)

output:维度为(seq_len, batch, num_directions * hidden_size)。

h_n:最后时刻的输出隐藏状态,维度为 (num_layers * num_directions, batch, hidden_size)

c_n:最后时刻的输出单元状态,维度与h_n相同。

LSTM的变量:

Pytorch实现LSTM和GRU示例

以MNIST分类为例实现LSTM分类

MNIST图片大小为28×28,可以将每张图片看做是长为28的序列,序列中每个元素的特征维度为28。将最后输出的隐藏状态Pytorch实现LSTM和GRU示例 作为抽象的隐藏特征输入到全连接层进行分类。最后输出的

导入头文件:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
  def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
    super(Rnn, self).__init__()
    self.n_layer = n_layer
    self.hidden_dim = hidden_dim
    self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
    self.classifier = nn.Linear(hidden_dim, n_classes)

  def forward(self, x):
    out, (h_n, c_n) = self.lstm(x)
    # 此时可以从out中获得最终输出的状态h
    # x = out[:, -1, :]
    x = h_n[-1, :, :]
    x = self.classifier(x)
    return x

训练和测试代码:

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize([0.5], [0.5]),
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

net = Rnn(28, 10, 2, 10)

net = net.to('cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

# Training
def train(epoch):
  print('\nEpoch: %d' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    inputs, targets = inputs.to('cpu'), targets.to('cpu')
    optimizer.zero_grad()
    outputs = net(torch.squeeze(inputs, 1))
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    train_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

def test(epoch):
  global best_acc
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
      inputs, targets = inputs.to('cpu'), targets.to('cpu')
      outputs = net(torch.squeeze(inputs, 1))
      loss = criterion(outputs, targets)

      test_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()

      print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))




for epoch in range(200):
  train(epoch)
  test(epoch)

以上这篇Pytorch实现LSTM和GRU示例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
测试、预发布后用python检测网页是否有日常链接
Jun 03 Python
使用Python中的cookielib模拟登录网站
Apr 09 Python
python简单读取大文件的方法
Jul 01 Python
Python之py2exe打包工具详解
Jun 14 Python
Python实现图片尺寸缩放脚本
Mar 10 Python
Python 打印中文字符的三种方法
Aug 14 Python
python实现将视频按帧读取到自定义目录
Dec 10 Python
Python.append()与Python.expand()用法详解
Dec 18 Python
Python urlopen()和urlretrieve()用法解析
Jan 07 Python
Django限制API访问频率常用方法解析
Oct 12 Python
python将YUV420P文件转PNG图片格式的两种方法
Jan 22 Python
Python虚拟环境virtualenv是如何使用的
Jun 20 Python
Python生成词云的实现代码
Jan 14 #Python
pytorch-RNN进行回归曲线预测方式
Jan 14 #Python
利用pytorch实现对CIFAR-10数据集的分类
Jan 14 #Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
You might like
SONY ICF-SW07收音机电路分析
2021/03/02 无线电
用缓存实现静态页面的测试
2006/12/06 PHP
php使用unset()删除数组中某个单元(键)的方法
2015/02/17 PHP
php判断IP地址是否在多个IP段内
2020/08/18 PHP
js网页侧边随页面滚动广告效果实现
2011/04/14 Javascript
jquery判断元素是否隐藏的多种方法
2014/05/06 Javascript
JS实现回到页面顶部动画效果的简单实例
2016/05/24 Javascript
Bootstrap表格使用方法详解
2017/02/17 Javascript
js 客户端打印html 并且去掉页眉、页脚的实例
2017/11/03 Javascript
浅析Vue自定义组件的v-model
2017/11/26 Javascript
Vue 实现展开折叠效果的示例代码
2018/08/27 Javascript
Vue.js路由实现选项卡简单实例
2019/07/24 Javascript
微信小程序获取公众号文章列表及显示文章的示例代码
2020/03/10 Javascript
python连接oracle数据库实例
2014/10/17 Python
python获取时间及时间格式转换问题实例代码详解
2018/12/06 Python
对Python3 pyc 文件的使用详解
2019/02/16 Python
opencv python图像梯度实例详解
2020/02/04 Python
使用 django orm 写 exists 条件过滤实例
2020/05/20 Python
Python 创建TCP服务器的方法
2020/07/28 Python
利用Python实现字幕挂载(把字幕文件与视频合并)思路详解
2020/10/21 Python
python中的split、rsplit、splitlines用法说明
2020/10/23 Python
Python 中的函数装饰器和闭包详解
2021/02/06 Python
10分钟理解CSS3 Grid布局
2018/12/20 HTML / CSS
波兰最大的儿童服装连锁店之一:5.10.15.
2018/02/11 全球购物
Mountain Hardwear官网:攀岩服装和户外装备
2019/09/26 全球购物
法国包包和行李箱销售网站:Bagage24.fr
2020/03/24 全球购物
知识改变命运演讲稿
2014/05/21 职场文书
贫困证明模板(3篇)
2014/09/16 职场文书
未中标通知书
2015/04/17 职场文书
JS监听Esc 键触发事键
2021/04/14 Javascript
python基于tkinter实现gif录屏功能
2021/05/19 Python
python opencv旋转图片的使用方法
2021/06/04 Python
教你用Python爬取英雄联盟皮肤原画
2021/06/13 Python
MongoDB连接数据库并创建数据等使用方法
2021/11/27 MongoDB
Element-ui Layout布局(Row和Col组件)的实现
2021/12/06 Vue.js
windows11选中自动复制怎么开启? Win11自动复制所选内容的方法
2022/07/23 数码科技