pytorch 如何使用batch训练lstm网络


Posted in Python onMay 28, 2021

batch的lstm

# 导入相应的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data 
torch.manual_seed(1) 
 
# 准备数据的阶段
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)
  
with open("/home/lstm_train.txt", encoding='utf8') as f:
    train_data = []
    word = []
    label = []
    data = f.readline().strip()
    while data:
        data = data.strip()
        SP = data.split(' ')
        if len(SP) == 2:
            word.append(SP[0])
            label.append(SP[1])
        else:
            if len(word) == 100 and 'I-PRO' in label:
                train_data.append((word, label))
            word = []
            label = []
        data = f.readline()
 
word_to_ix = {}
for sent, _ in train_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
 
tag_to_ix = {"O": 0, "I-PRO": 1}
for i in range(len(train_data)):
    train_data[i] = ([word_to_ix[t] for t in train_data[i][0]], [tag_to_ix[t] for t in train_data[i][1]])
 
# 词向量的维度
EMBEDDING_DIM = 128
 
# 隐藏层的单元数
HIDDEN_DIM = 128
 
# 批大小
batch_size = 10  
class LSTMTagger(nn.Module):
 
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, batch_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
 
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
 
        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
 
    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        # input_tensor = embeds.view(self.batch_size, len(sentence) // self.batch_size, -1)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores
 
    def predict(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores 
 
loss_function = nn.NLLLoss()
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), batch_size)
optimizer = optim.SGD(model.parameters(), lr=0.1)
 
data_set_word = []
data_set_label = []
for data_tuple in train_data:
    data_set_word.append(data_tuple[0])
    data_set_label.append(data_tuple[1])
torch_dataset = Data.TensorDataset(torch.tensor(data_set_word, dtype=torch.long), torch.tensor(data_set_label, dtype=torch.long))
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=batch_size,  # mini batch size
    shuffle=True,  #
    num_workers=2,  # 多线程来读数据
)
 
# 训练过程
for epoch in range(200):
    for step, (batch_x, batch_y) in enumerate(loader):
        # 梯度清零
        model.zero_grad()
        tag_scores = model(batch_x)
 
        # 计算损失
        tag_scores = tag_scores.view(-1, tag_scores.shape[2])
        batch_y = batch_y.view(batch_y.shape[0]*batch_y.shape[1])
        loss = loss_function(tag_scores, batch_y)
        print(loss)
        # 后向传播
        loss.backward()
 
        # 更新参数
        optimizer.step()
 
# 测试过程
with torch.no_grad():
    inputs = torch.tensor([data_set_word[0]], dtype=torch.long)
    print(inputs)
    tag_scores = model.predict(inputs)
    print(tag_scores.shape)
    print(torch.argmax(tag_scores, dim=2))

补充:PyTorch基础-使用LSTM神经网络实现手写数据集识别

看代码吧~

import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置
                            train = True, # 载入训练集
                            transform=transforms.ToTensor(), # 把数据变成tensor类型
                            download = True # 下载
                           )
# 测试集
test_data = datasets.MNIST(root="./",
                            train = False,
                            transform=transforms.ToTensor(),
                            download = True
                           )
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):
    inputs,labels = data
    print(inputs.shape)
    print(labels.shape)
    break
# 定义网络结构
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM,self).__init__()# 初始化
        self.lstm = torch.nn.LSTM(
            input_size = 28, # 表示输入特征的大小
            hidden_size = 64, # 表示lstm模块的数量
            num_layers = 1, # 表示lstm隐藏层的层数
            batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature)
        )
        self.out = torch.nn.Linear(in_features=64,out_features=10)
        self.softmax = torch.nn.Softmax(dim=1)
    def forward(self,x):
        # (batch,seq_len,feature)
        x = x.view(-1,28,28)
        # output:(batch,seq_len,hidden_size)包含每个序列的输出结果
        # 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers
        # h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        # c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        output,(h_n,c_n) = self.lstm(x)
        output_in_last_timestep = h_n[-1,:,:]
        x = self.out(output_in_last_timestep)
        x = self.softmax(x)
        return x
# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
# 定义模型训练和测试的方法
def train():
    # 模型的训练状态
    model.train()
    for i,data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 交叉熵代价函数out(batch,C:类别的数量),labels(batch)
        loss = mse_loss(out,labels)
        # 梯度清零
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()
        
def test():
    # 模型的测试状态
    model.eval()
    correct = 0 # 测试集准确率
    for i,data in enumerate(test_loader):
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 获得最大值,以及最大值所在的位置
        _,predicted = torch.max(out,1)
        # 预测正确的数量
        correct += (predicted==labels).sum()
    print("Test acc:{0}".format(correct.item()/len(test_data)))
    
    correct = 0
    for i,data in enumerate(train_loader): # 训练集准确率
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 获得最大值,以及最大值所在的位置
        _,predicted = torch.max(out,1)
        # 预测正确的数量
        correct += (predicted==labels).sum()
    print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(10):
    print("epoch:",epoch)
    train()
    test()

pytorch 如何使用batch训练lstm网络

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python安装mysql-python简明笔记(ubuntu环境)
Jun 25 Python
Python守护线程用法实例
Jun 23 Python
Windows平台Python连接sqlite3数据库的方法分析
Jul 12 Python
Python2.7下安装Scrapy框架步骤教程
Dec 22 Python
Python决策树分类算法学习
Dec 22 Python
对python numpy数组中冒号的使用方法详解
Apr 17 Python
详解Pytorch 使用Pytorch拟合多项式(多项式回归)
May 24 Python
解决python打不开文件(文件不存在)的问题
Feb 18 Python
Python实现判断一个整数是否为回文数算法示例
Mar 02 Python
Python叠加两幅栅格图像的实现方法
Jul 05 Python
Python全局变量与global关键字常见错误解决方案
Oct 05 Python
flask框架中的cookie和session使用
Jan 31 Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
You might like
AMFPHP php远程调用(RPC, Remote Procedure Call)工具 快速入门教程
2010/05/10 PHP
PHP调用C#开发的dll类库方法
2014/07/28 PHP
实现WordPress主题侧边栏切换功能的PHP脚本详解
2015/12/14 PHP
php结合web uploader插件实现分片上传文件
2016/05/10 PHP
简单理解PHP的面向对象编程方式
2016/05/17 PHP
phpstorm激活码2020附使用详细教程
2020/09/25 PHP
优化网页之快速的呈现我们的网页
2007/06/29 Javascript
JavaScript 学习笔记(十五)
2010/01/28 Javascript
javascript继承之为什么要继承
2012/11/10 Javascript
js打开windows上的可执行文件示例
2014/05/27 Javascript
AngularJS基础学习笔记之指令
2015/05/10 Javascript
JS实现光滑展开合拢的菜单效果代码
2015/09/16 Javascript
AngularJs ng-route路由详解及实例代码
2016/09/14 Javascript
AngularJS中一般函数参数传递用法分析
2016/11/22 Javascript
纯js代码生成可搜索选择下拉列表的实例
2018/01/11 Javascript
vue-router+nginx 非根路径配置方法
2018/06/30 Javascript
vue 纯js监听滚动条到底部的实例讲解
2018/09/03 Javascript
详解webpack打包nodejs项目(前端代码)
2018/09/19 NodeJs
vue中@change兼容问题详解
2019/10/25 Javascript
vue组件讲解(is属性的用法)模板标签替换操作
2020/09/04 Javascript
在Gnumeric下使用Python脚本操作表格的教程
2015/04/14 Python
Python随手笔记之标准类型内建函数
2015/12/02 Python
梯度下降法介绍及利用Python实现的方法示例
2017/07/12 Python
Python微信企业号开发之回调模式接收微信端客户端发送消息及被动返回消息示例
2017/08/21 Python
tf.truncated_normal与tf.random_normal的详细用法
2018/03/05 Python
解决Django migrate No changes detected 不能创建表的问题
2018/05/27 Python
python解析含有重复key的json方法
2019/01/22 Python
python爬取微信公众号文章的方法
2019/02/26 Python
详解用Python为直方图绘制拟合曲线的两种方法
2019/08/21 Python
Django 自动生成api接口文档教程
2019/11/19 Python
AmazeUI 缩略图的实现示例
2020/08/18 HTML / CSS
美国大尺码女装零售商:TORRID
2016/10/01 全球购物
美国性感内衣店:Yandy
2018/06/12 全球购物
平面设计专业大学生职业规划书
2014/03/12 职场文书
预备党员入党感想
2015/08/10 职场文书
解决SpringBoot跨域的三种方式
2021/06/26 Java/Android