pytorch 如何使用batch训练lstm网络


Posted in Python onMay 28, 2021

batch的lstm

# 导入相应的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data 
torch.manual_seed(1) 
 
# 准备数据的阶段
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)
  
with open("/home/lstm_train.txt", encoding='utf8') as f:
    train_data = []
    word = []
    label = []
    data = f.readline().strip()
    while data:
        data = data.strip()
        SP = data.split(' ')
        if len(SP) == 2:
            word.append(SP[0])
            label.append(SP[1])
        else:
            if len(word) == 100 and 'I-PRO' in label:
                train_data.append((word, label))
            word = []
            label = []
        data = f.readline()
 
word_to_ix = {}
for sent, _ in train_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
 
tag_to_ix = {"O": 0, "I-PRO": 1}
for i in range(len(train_data)):
    train_data[i] = ([word_to_ix[t] for t in train_data[i][0]], [tag_to_ix[t] for t in train_data[i][1]])
 
# 词向量的维度
EMBEDDING_DIM = 128
 
# 隐藏层的单元数
HIDDEN_DIM = 128
 
# 批大小
batch_size = 10  
class LSTMTagger(nn.Module):
 
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, batch_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
 
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
 
        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
 
    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        # input_tensor = embeds.view(self.batch_size, len(sentence) // self.batch_size, -1)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores
 
    def predict(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, _ = self.lstm(embeds)
        tag_space = self.hidden2tag(lstm_out)
        scores = F.log_softmax(tag_space, dim=2)
        return scores 
 
loss_function = nn.NLLLoss()
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), batch_size)
optimizer = optim.SGD(model.parameters(), lr=0.1)
 
data_set_word = []
data_set_label = []
for data_tuple in train_data:
    data_set_word.append(data_tuple[0])
    data_set_label.append(data_tuple[1])
torch_dataset = Data.TensorDataset(torch.tensor(data_set_word, dtype=torch.long), torch.tensor(data_set_label, dtype=torch.long))
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=batch_size,  # mini batch size
    shuffle=True,  #
    num_workers=2,  # 多线程来读数据
)
 
# 训练过程
for epoch in range(200):
    for step, (batch_x, batch_y) in enumerate(loader):
        # 梯度清零
        model.zero_grad()
        tag_scores = model(batch_x)
 
        # 计算损失
        tag_scores = tag_scores.view(-1, tag_scores.shape[2])
        batch_y = batch_y.view(batch_y.shape[0]*batch_y.shape[1])
        loss = loss_function(tag_scores, batch_y)
        print(loss)
        # 后向传播
        loss.backward()
 
        # 更新参数
        optimizer.step()
 
# 测试过程
with torch.no_grad():
    inputs = torch.tensor([data_set_word[0]], dtype=torch.long)
    print(inputs)
    tag_scores = model.predict(inputs)
    print(tag_scores.shape)
    print(torch.argmax(tag_scores, dim=2))

补充:PyTorch基础-使用LSTM神经网络实现手写数据集识别

看代码吧~

import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置
                            train = True, # 载入训练集
                            transform=transforms.ToTensor(), # 把数据变成tensor类型
                            download = True # 下载
                           )
# 测试集
test_data = datasets.MNIST(root="./",
                            train = False,
                            transform=transforms.ToTensor(),
                            download = True
                           )
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):
    inputs,labels = data
    print(inputs.shape)
    print(labels.shape)
    break
# 定义网络结构
class LSTM(nn.Module):
    def __init__(self):
        super(LSTM,self).__init__()# 初始化
        self.lstm = torch.nn.LSTM(
            input_size = 28, # 表示输入特征的大小
            hidden_size = 64, # 表示lstm模块的数量
            num_layers = 1, # 表示lstm隐藏层的层数
            batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature)
        )
        self.out = torch.nn.Linear(in_features=64,out_features=10)
        self.softmax = torch.nn.Softmax(dim=1)
    def forward(self,x):
        # (batch,seq_len,feature)
        x = x.view(-1,28,28)
        # output:(batch,seq_len,hidden_size)包含每个序列的输出结果
        # 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers
        # h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        # c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
        output,(h_n,c_n) = self.lstm(x)
        output_in_last_timestep = h_n[-1,:,:]
        x = self.out(output_in_last_timestep)
        x = self.softmax(x)
        return x
# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
# 定义模型训练和测试的方法
def train():
    # 模型的训练状态
    model.train()
    for i,data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 交叉熵代价函数out(batch,C:类别的数量),labels(batch)
        loss = mse_loss(out,labels)
        # 梯度清零
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()
        
def test():
    # 模型的测试状态
    model.eval()
    correct = 0 # 测试集准确率
    for i,data in enumerate(test_loader):
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 获得最大值,以及最大值所在的位置
        _,predicted = torch.max(out,1)
        # 预测正确的数量
        correct += (predicted==labels).sum()
    print("Test acc:{0}".format(correct.item()/len(test_data)))
    
    correct = 0
    for i,data in enumerate(train_loader): # 训练集准确率
        # 获得一个批次的数据和标签
        inputs,labels = data
        # 获得模型预测结果(64,10)
        out = model(inputs)
        # 获得最大值,以及最大值所在的位置
        _,predicted = torch.max(out,1)
        # 预测正确的数量
        correct += (predicted==labels).sum()
    print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(10):
    print("epoch:",epoch)
    train()
    test()

pytorch 如何使用batch训练lstm网络

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的加密模块md5、sha、crypt使用实例
Sep 28 Python
python求解水仙花数的方法
May 11 Python
Python使用django搭建web开发环境
Jun 09 Python
Python有序查找算法之二分法实例分析
Dec 11 Python
python 用for循环实现1~n求和的实例
Feb 01 Python
python实现LBP方法提取图像纹理特征实现分类的步骤
Jul 11 Python
使用pandas读取文件的实现
Jul 31 Python
python脚本实现mp4中的音频提取并保存在原目录
Feb 27 Python
Python类super()及私有属性原理解析
Jun 15 Python
使用tensorflow进行音乐类型的分类
Aug 14 Python
如何基于Python pygame实现动画跑马灯
Nov 18 Python
关于python类SortedList详解
Sep 04 Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
解决Pytorch dataloader时报错每个tensor维度不一样的问题
May 28 #Python
You might like
php中全局变量global的使用演示代码
2011/05/18 PHP
JoshChen_php新手进阶高手不可或缺的规范介绍
2013/08/16 PHP
php动态生成版权所有信息的方法
2015/03/24 PHP
php将html转成wml的WAP标记语言实例
2015/07/08 PHP
php实现微信模板消息推送
2018/03/30 PHP
php遍历目录下文件并按修改时间排序操作示例
2019/07/12 PHP
PHP如何开启Opcache功能提升程序处理效率
2020/04/27 PHP
fix-ie5.js扩展在IE5下不能使用的几个方法
2007/08/20 Javascript
JQUERY CHECKBOX全选,取消全选,反选方法三
2008/08/30 Javascript
Dojo 学习笔记入门篇 First Dojo Example
2009/11/15 Javascript
jquery 简短右键菜单 多浏览器兼容
2010/01/01 Javascript
js+html+css实现鼠标移动div实例
2013/01/30 Javascript
Javascript 鼠标移动上去小三角形滑块缓慢跟随效果
2013/04/26 Javascript
微信小程序 开发指南详解
2016/09/27 Javascript
利用纯JS实现像素逐渐显示的方法示例
2017/08/14 Javascript
详解创建自定义的Angular Schematics
2018/06/06 Javascript
Vue安装浏览器开发工具的步骤详解
2019/05/12 Javascript
Layui事件监听的实现(表单和数据表格)
2019/10/17 Javascript
vue实现的多页面项目如何优化打包的步骤详解
2020/07/19 Javascript
Vue+Element ui 根据后台返回数据设置动态表头操作
2020/09/21 Javascript
[05:43]VG.R战队教练Mikasa专访:为目标从未停止战斗
2016/08/02 DOTA
Python多进程编程技术实例分析
2014/09/16 Python
Python实现八大排序算法
2016/08/13 Python
Python cookbook(字符串与文本)针对任意多的分隔符拆分字符串操作示例
2018/04/19 Python
基于python的socket实现单机五子棋到双人对战
2020/03/24 Python
python多线程并发实例及其优化
2019/06/27 Python
Python学习笔记之错误和异常及访问错误消息详解
2019/08/08 Python
使用OpenCV获取图像某点的颜色值,并设置某点的颜色
2020/06/02 Python
Python 解析xml文件的示例
2020/09/29 Python
python时间time模块处理大全
2020/10/25 Python
英国第一的购买便宜玩具和游戏的在线购物网站:Bargain Max
2018/01/24 全球购物
汇科协同Java笔试题
2012/03/31 面试题
大二自我鉴定范文
2013/10/05 职场文书
后勤自我鉴定
2013/10/13 职场文书
见习报告的格式
2014/10/31 职场文书
Win11如何修改dns?Win11修改dns图文教程
2022/01/18 数码科技