pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的sort方法使用详解
Jul 25 Python
列举Python中吸引人的一些特性
Apr 09 Python
python中subprocess批量执行linux命令
Apr 27 Python
Flask框架使用DBUtils模块连接数据库操作示例
Jul 20 Python
python简单实现矩阵的乘,加,转置和逆运算示例
Jul 10 Python
python破解bilibili滑动验证码登录功能
Sep 11 Python
python词云库wordCloud使用方法详解(解决中文乱码)
Feb 17 Python
如何利用python之wxpy模块玩转微信
Aug 17 Python
通过代码实例了解Python异常本质
Sep 16 Python
pyqt5实现井字棋的示例代码
Dec 07 Python
Python数据分析库pandas高级接口dt的使用详解
Dec 11 Python
python 利用jieba.analyse进行 关键词提取
Dec 17 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
Trying to clone an uncloneable object of class Imagic的解决方法
2012/01/11 PHP
destoon实现VIP排名一直在前面排序的方法
2014/08/21 PHP
PHP中CheckBox多选框上传失败的代码写法
2017/02/13 PHP
thinkPHP5.0框架简单配置作用域的方法
2017/03/17 PHP
Gambit vs ForZe BO3 第三场 2.13
2021/03/10 DOTA
可拖动窗口,附带鼠标控制渐变透明,开启关闭功能
2006/06/26 Javascript
extjs关于treePanel+chekBox全部选中以及清空选中问题探讨
2013/04/02 Javascript
jquery中focus()函数实现当对象获得焦点后自动把光标移到内容最后
2013/09/29 Javascript
jquery对ajax的支持介绍
2013/12/10 Javascript
JS阻止冒泡事件以及默认事件发生的简单方法
2014/01/17 Javascript
jQuery模拟点击A标记示例参考
2014/04/17 Javascript
jQuery插件分享之分页插件jqPagination
2014/06/06 Javascript
Node.js文件操作方法汇总
2016/03/22 Javascript
jQuery验证插件validate使用详解
2016/05/11 Javascript
BootStrap轻松实现微信页面开发代码分享
2016/10/21 Javascript
详解js前端代码异常监控
2017/01/11 Javascript
详谈js遍历集合(Array,Map,Set)
2017/04/06 Javascript
vue+iview写个弹框的示例代码
2017/12/05 Javascript
JS Object.preventExtensions(),Object.seal()与Object.freeze()用法实例分析
2018/08/25 Javascript
vue头部导航动态点击处理方法
2018/11/02 Javascript
解决layui 三级联动下拉框更新时回显的问题
2019/09/03 Javascript
Javascript柯里化实现原理及作用解析
2020/10/22 Javascript
[01:38:19]夜魇凡尔赛茶话会 第五期
2021/03/11 DOTA
Python自动化测试ConfigParser模块读写配置文件
2016/08/15 Python
在 Python 应用中使用 MongoDB的方法
2017/01/05 Python
django启动uwsgi报错的解决方法
2018/04/08 Python
Python定时发送消息的脚本:每天跟你女朋友说晚安
2018/10/21 Python
最小二乘法及其python实现详解
2020/02/24 Python
浅谈keras中的keras.utils.to_categorical用法
2020/07/02 Python
Python Merge函数原理及用法解析
2020/09/16 Python
纯CSS3实现鼠标悬停提示气泡效果
2014/02/28 HTML / CSS
爱尔兰电脑、家电和家具购物网站:Buy It Direct
2019/07/09 全球购物
给排水专业应届生求职信
2013/10/12 职场文书
化学学院毕业生自荐信范文
2013/12/17 职场文书
家长会欢迎词
2015/01/23 职场文书
css实现两栏布局,左侧固定宽,右侧自适应的多种方法
2021/08/07 HTML / CSS