pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的多线程编程
Apr 09 Python
详解Python命令行解析工具Argparse
Apr 20 Python
python编程实现希尔排序
Apr 13 Python
Python实现简单的语音识别系统
Dec 13 Python
Django之路由层的实现
Sep 09 Python
Python中类似于jquery的pyquery库用法分析
Dec 02 Python
Pytorch之Variable的用法
Dec 31 Python
适合Python初学者的一些编程技巧
Feb 12 Python
大数据分析用java还是Python
Jul 06 Python
python如何变换环境
Jul 21 Python
Python tkinter制作单机五子棋游戏
Sep 14 Python
一篇文章弄懂Python关键字、标识符和变量
Jul 15 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
简单易用的计数器(数据库)
2006/10/09 PHP
生成静态页面的PHP类
2006/11/25 PHP
php中计算时间差的几种方法
2009/12/31 PHP
PHP函数按引用传递参数及函数可选参数用法示例
2018/06/04 PHP
PHP绕过open_basedir限制操作文件的方法
2018/06/10 PHP
40个新鲜出炉的jQuery 插件和免费教程[上]
2012/07/24 Javascript
jquery ajax jsonp跨域调用实例代码
2013/12/11 Javascript
JavaScript字符串对象toUpperCase方法入门实例(用于把字母转换为大写)
2014/10/17 Javascript
jquery checkbox 勾选的bug问题解决方案与分析
2014/11/13 Javascript
jQuery中removeClass()方法用法实例
2015/01/05 Javascript
Js类的静态方法与实例方法区分及jQuery拓展的两种方法
2016/06/03 Javascript
javascript稀疏数组(sparse array)和密集数组用法分析
2016/12/28 Javascript
从vue源码看props的用法
2019/01/09 Javascript
vue全局自定义指令-元素拖拽的实现代码
2019/04/14 Javascript
js事件触发操作实例分析
2019/06/21 Javascript
jquery轻量级数字动画插件countUp.js使用详解
2019/10/17 jQuery
js在HTML的三种引用方式详解
2020/08/29 Javascript
[14:21]VICI vs EG (BO3)
2018/06/07 DOTA
python实现保存网页到本地示例
2014/03/16 Python
python基础教程之简单入门说明(变量和控制语言使用方法)
2014/03/25 Python
修改Python的pyxmpp2中的主循环使其提高性能
2015/04/24 Python
python numpy数组的索引和切片的操作方法
2018/10/20 Python
Python面向对象之类和实例用法分析
2019/06/08 Python
python虚拟环境完美部署教程
2019/08/06 Python
Python实现生成密码字典的方法示例
2019/09/02 Python
Python基于codecs模块实现文件读写案例解析
2020/05/11 Python
英国50岁以上人群的交友网站:Ourtime
2018/03/28 全球购物
2014年教师培训的自我评价
2014/01/03 职场文书
餐饮营销方案
2014/02/23 职场文书
主题党日活动总结
2014/07/08 职场文书
写给女朋友的保证书
2015/05/09 职场文书
离婚律师函范本
2015/05/27 职场文书
标枪加油稿
2015/07/22 职场文书
新员工入职感想
2015/08/07 职场文书
浅谈如何写好演讲稿?
2019/06/12 职场文书
教师学期述职自我鉴定
2019/08/16 职场文书