pytorch 如何使用batch训练lstm网络


Posted in Python onMay 28, 2021

batch的lstm

# 导入相应的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data 
torch.manual_seed(1) 
 
# 准备数据的阶段
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)
  
with open("/home/lstm_train.txt", encoding='utf8') as f:
    train_data = []
    word = []
    label = []
    data = f.readline().strip()
    while data:
        data = data.strip()
        SP = data.split(' ')
        if len(SP) == 2:
            word.append(SP[0])
            label.append(SP[1])
        else:
            if len(word) == 100 and 'I-PRO' in label:
                train_data.append((word, label))
            word = []
            label = []
        data = f.readline()
 
word_to_ix = {}
for sent, _ in train_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
 
tag_to_ix = {"O": 0, "I-PRO": 1}
for i in range(len(train_data)):
    train_data[i] = ([word_to_ix[t] for t in train_data[i][0]], [tag_to_ix[t] for t in train_data[i][1]])
 
# 词向量的维度
EMBEDDING_DIM = 128
 
# 隐藏层的单元数
HIDDEN_DIM = 128
 
# 批大小
batch_size = 10  
class LSTMTagger(nn.Module):
 
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, batch_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
 
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
 
        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
 
    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        # input_tensor = embeds.view(self.batch_size, len(sentence) // self.batch_size, -1)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores
 
    def predict(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores 
 
loss_function = nn.NLLLoss()
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), batch_size)
optimizer = optim.SGD(model.parameters(), lr=0.1)
 
data_set_word = []
data_set_label = []
for data_tuple in train_data:
    data_set_word.append(data_tuple[0])
    data_set_label.append(data_tuple[1])
torch_dataset = Data.TensorDataset(torch.tensor(data_set_word, dtype=torch.long), torch.tensor(data_set_label, dtype=torch.long))
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=batch_size,  # mini batch size
    shuffle=True,  #
    num_workers=2,  # 多线程来读数据
)
 
# 训练过程
for epoch in range(200):
    for step, (batch_x, batch_y) in enumerate(loader):
        # 梯度清零
        model.zero_grad()
        tag_scores = model(batch_x)
 
        # 计算损失
        tag_scores = tag_scores.view(-1, tag_scores.shape[2])
        batch_y = batch_y.view(batch_y.shape[0]*batch_y.shape[1])
        loss = loss_function(tag_scores, batch_y)
        print(loss)
        # 后向传播
        loss.backward()
 
        # 更新参数
        optimizer.step()
 
# 测试过程
with torch.no_grad():
    inputs = torch.tensor([data_set_word[0]], dtype=torch.long)
    print(inputs)
    tag_scores = model.predict(inputs)
    print(tag_scores.shape)
    print(torch.argmax(tag_scores, dim=2))

补充:PyTorch基础-使用LSTM神经网络实现手写数据集识别

看代码吧~

import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置
                            train = True, # 载入训练集
                            transform=transforms.ToTensor(), # 把数据变成tensor类型
                            download = True # 下载
                           )
# 测试集
test_data = datasets.MNIST(root="./",
                            train = False,
                            transform=transforms.ToTensor(),
                            download = True
                           )
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):
    inputs,labels = data
    print(inputs.shape)
    print(labels.shape)
    break
# 定义网络结构
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM,self).__init__()# 初始化
        self.lstm = torch.nn.LSTM(
            input_size = 28, # 表示输入特征的大小
            hidden_size = 64, # 表示lstm模块的数量
            num_layers = 1, # 表示lstm隐藏层的层数
            batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature)
        )
        self.out = torch.nn.Linear(in_features=64,out_features=10)
        self.softmax = torch.nn.Softmax(dim=1)
    def forward(self,x):
        # (batch,seq_len,feature)
        x = x.view(-1,28,28)
        # output:(batch,seq_len,hidden_size)包含每个序列的输出结果
        # 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers
        # h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        # c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        output,(h_n,c_n) = self.lstm(x)
        output_in_last_timestep = h_n[-1,:,:]
        x = self.out(output_in_last_timestep)
        x = self.softmax(x)
        return x
# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
# 定义模型训练和测试的方法
def train():
    # 模型的训练状态
    model.train()
    for i,data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 交叉熵代价函数out(batch,C:类别的数量),labels(batch)
        loss = mse_loss(out,labels)
        # 梯度清零
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()
        
def test():
    # 模型的测试状态
    model.eval()
    correct = 0 # 测试集准确率
    for i,data in enumerate(test_loader):
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 获得最大值,以及最大值所在的位置
        _,predicted = torch.max(out,1)
        # 预测正确的数量
        correct += (predicted==labels).sum()
    print("Test acc:{0}".format(correct.item()/len(test_data)))
    
    correct = 0
    for i,data in enumerate(train_loader): # 训练集准确率
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 获得最大值,以及最大值所在的位置
        _,predicted = torch.max(out,1)
        # 预测正确的数量
        correct += (predicted==labels).sum()
    print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(10):
    print("epoch:",epoch)
    train()
    test()

pytorch 如何使用batch训练lstm网络

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3实现连接SQLite数据库的方法
Aug 23 Python
Python入门篇之编程习惯与特点
Oct 17 Python
Python文档生成工具pydoc使用介绍
Jun 02 Python
Python模拟脉冲星伪信号频率实例代码
Jan 03 Python
对Tensorflow中的矩阵运算函数详解
Jul 27 Python
简单介绍python封装的基本知识
Aug 10 Python
python 一篇文章搞懂装饰器所有用法(建议收藏)
Aug 23 Python
python如何基于redis实现ip代理池
Jan 17 Python
python 实现字符串下标的输出功能
Feb 13 Python
Python运行异常管理解决方案
Mar 09 Python
python操作链表的示例代码
Sep 27 Python
Python中tkinter的用户登录管理的实现
Apr 22 Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
You might like
PHP中在数据库中保存Checkbox数据(1)
2006/10/09 PHP
php字符编码转换之gb2312转为utf8
2013/10/28 PHP
php 多文件上传的实现实例
2016/10/23 PHP
动态加载js的几种方法
2006/10/23 Javascript
Javascript条件判断使用小技巧总结
2008/09/08 Javascript
基于jQuery的的一个隔行变色,鼠标移动变色的小插件
2010/07/06 Javascript
JavaScript中使用构造函数实现继承的代码
2010/08/12 Javascript
深入理解JavaScript系列(39):设计模式之适配器模式详解
2015/03/04 Javascript
JavaScript判断是否为数组的3种方法及效率比较
2015/04/01 Javascript
js实现使用鼠标拖拽切换图片的方法
2015/05/04 Javascript
jQuery简单实现日历的方法
2015/05/04 Javascript
jquery遍历标签中自定义的属性方法
2016/09/17 Javascript
关于webuploader插件使用过程遇到的小问题
2016/11/07 Javascript
JSONP基础知识详解
2017/03/19 Javascript
JS简单生成随机数(随机密码)的方法
2017/05/11 Javascript
vue 怎么创建组件及组件使用方法
2017/07/27 Javascript
浅谈Vue.js中ref ($refs)用法举例总结
2017/12/19 Javascript
详解vue-meta如何让你更优雅的管理头部标签
2018/01/18 Javascript
js实现跟随鼠标移动的小球
2019/08/26 Javascript
vue实现页面内容禁止选中功能,仅输入框和文本域可选
2019/11/09 Javascript
基于JavaScript的数据结构队列动画实现示例解析
2020/08/06 Javascript
[41:52]2018DOTA2亚洲邀请赛3月29日 小组赛A组 TNC VS OpTic
2018/03/30 DOTA
python模拟新浪微博登陆功能(新浪微博爬虫)
2013/12/24 Python
TensorFlow模型保存/载入的两种方法
2018/03/08 Python
Python格式化日期时间操作示例
2018/06/28 Python
实例讲解Python脚本成为Windows中运行的exe文件
2019/01/24 Python
Pytorch 实现自定义参数层的例子
2019/08/17 Python
Python pip 安装与使用(安装、更新、删除)
2019/10/06 Python
基于梯度爆炸的解决方法:clip gradient
2020/02/04 Python
Django 解决distinct无法去除重复数据的问题
2020/05/20 Python
SportsDirect.com马来西亚:英国第一体育零售商
2018/11/21 全球购物
学生自我评语大全
2014/04/18 职场文书
推普周国旗下讲话稿
2014/09/21 职场文书
小学生大队委竞选稿
2015/11/20 职场文书
导游词之蜀山胜景瓦屋山
2019/11/29 职场文书
canvas实现贪食蛇的实践
2022/02/15 Javascript