pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python读取mp3中ID3信息的方法
Mar 05 Python
Python中处理字符串之endswith()方法的使用简介
May 18 Python
python关键字and和or用法实例
May 28 Python
Python 通过URL打开图片实例详解
Jun 01 Python
Python基础之getpass模块详细介绍
Aug 10 Python
Python金融数据可视化汇总
Nov 17 Python
Python 实现网页自动截图的示例讲解
May 17 Python
PyTorch线性回归和逻辑回归实战示例
May 22 Python
浅谈pycharm的xmx和xms设置方法
Dec 03 Python
python 计算一个字符串中所有数字的和实例
Jun 11 Python
python 一篇文章搞懂装饰器所有用法(建议收藏)
Aug 23 Python
用python3读取python2的pickle数据方式
Dec 25 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
深入解析Session是否必须依赖Cookie
2013/08/02 PHP
PHP使用两个栈实现队列功能的方法
2018/01/15 PHP
php使用fputcsv实现大数据的导出操作详解
2020/02/27 PHP
jQuery.extend 函数详解
2012/02/03 Javascript
Javascript事件实例详解
2013/11/06 Javascript
jquery对单选框,多选框,文本框等常见操作小结
2014/01/08 Javascript
通过隐藏iframe实现无刷新上传文件操作
2016/03/16 Javascript
简述jQuery Easyui一些用法
2017/08/01 jQuery
jQuery选择器中的特殊符号处理方法
2017/09/08 jQuery
Vue 2.0学习笔记之Vue中的computed属性
2017/10/16 Javascript
JS开发 富文本编辑器TinyMCE详解
2019/07/19 Javascript
[03:49]显微镜下的DOTA2第十五期—VG登基之路完美团
2014/06/24 DOTA
[02:41]DOTA2亚洲邀请赛小组赛第三日 赛事回顾
2015/02/01 DOTA
python正常时间和unix时间戳相互转换的方法
2015/04/23 Python
Python max内置函数详细介绍
2016/11/17 Python
python3实现爬取淘宝美食代码分享
2018/09/23 Python
python实现飞机大战游戏
2020/10/26 Python
python 实现在一张图中绘制一个小的子图方法
2019/07/07 Python
python与mysql数据库交互的实现
2020/01/06 Python
Python识别html主要文本框过程解析
2020/02/18 Python
详解pyqt5的UI中嵌入matplotlib图形并实时刷新(挖坑和填坑)
2020/08/07 Python
python 如何调用 dubbo 接口
2020/09/24 Python
MADE法国:提供原创设计师家具
2018/09/18 全球购物
印度购买眼镜和太阳镜网站:Coolwinks
2018/09/26 全球购物
国外软件测试工程师面试题
2016/12/09 面试题
AJAX都有哪些有点和缺点
2012/11/03 面试题
汽车专业大学生职业生涯规划范文
2014/01/07 职场文书
党员批评与自我批评思想汇报(集锦)
2014/09/14 职场文书
2014国庆节餐厅促销活动策划方案
2014/09/16 职场文书
会计实训报告范文
2014/11/04 职场文书
小学教师求职信范文
2015/03/20 职场文书
向雷锋同志学习倡议书
2015/04/27 职场文书
幼儿园家长反馈意见
2015/06/03 职场文书
2016元旦主持人开场白
2015/12/03 职场文书
Python实现生成bmp图像的方法
2021/06/13 Python
springboot应用服务启动事件的监听实现
2022/04/06 Java/Android