pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现dict版图遍历示例
Feb 19 Python
Python实现模拟登录及表单提交的方法
Jul 25 Python
python的socket编程入门
Jan 29 Python
flask session组件的使用示例
Dec 25 Python
Python中super函数用法实例分析
Mar 18 Python
python函数的万能参数传参详解
Jul 26 Python
Python3 翻转二叉树的实现
Sep 30 Python
Python 列表的清空方式
Jan 13 Python
python如何提取英语pdf内容并翻译
Mar 03 Python
自定义Django默认的sitemap站点地图样式
Mar 04 Python
Python如何输出整数
Jun 07 Python
解决python3输入的坑——input()
Dec 05 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
PHP 函数语法介绍一
2009/06/14 PHP
php中Y2K38的漏洞解决方法实例分析
2014/09/22 PHP
JS查看对象功能代码
2008/04/25 Javascript
IE6/7/8中Option元素未设value时Select将获取空字符串
2011/04/07 Javascript
javascript写的简单的计算器,内容很多,方法实用,推荐
2011/12/29 Javascript
jQuery之尺寸调整组件的深入解析
2013/06/19 Javascript
JavaScript不刷新实现浏览器的前进后退功能
2014/11/05 Javascript
jQuery简单实现隐藏以及显示特效
2015/02/26 Javascript
JQuery中绑定事件(bind())和移除事件(unbind())
2015/02/27 Javascript
js+html5获取用户地理位置信息并在Google地图上显示的方法
2015/06/05 Javascript
JavaScript继承模式粗探
2016/01/12 Javascript
Angular和Vue双向数据绑定的实现原理(重点是vue的双向绑定)
2016/11/22 Javascript
Vue原理剖析 实现双向绑定MVVM
2017/05/03 Javascript
ECMAscript 变量作用域总结概括
2017/08/18 Javascript
vue 右键菜单插件 简单、可扩展、样式自定义的右键菜单
2018/11/29 Javascript
Vue动态路由缓存不相互影响的解决办法
2019/02/19 Javascript
JS中的算法与数据结构之字典(Dictionary)实例详解
2019/08/20 Javascript
webpack 最佳配置指北(推荐)
2020/01/07 Javascript
vue中使用带隐藏文本信息的图片、图片水印的方法
2020/04/24 Javascript
[08:08]DOTA2-DPC中国联赛2月28日Recap集锦
2021/03/11 DOTA
使用Python生成XML的方法实例
2017/03/21 Python
numpy自动生成数组详解
2017/12/15 Python
详解python-图像处理(映射变换)
2019/03/22 Python
使用卷积神经网络(CNN)做人脸识别的示例代码
2020/03/27 Python
Python 中如何使用 virtualenv 管理虚拟环境
2021/01/21 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
2021/01/27 Python
基于PyInstaller各参数的含义说明
2021/03/04 Python
英国户外服装品牌:Craghoppers
2019/04/25 全球购物
商场端午节活动方案
2014/01/29 职场文书
医学专业毕业生推荐信
2014/07/12 职场文书
教师个人总结范文
2015/02/11 职场文书
大学优秀学生主要事迹材料
2015/11/04 职场文书
庭外和解协议书
2016/03/23 职场文书
Nginx配置之实现多台服务器负载均衡
2021/08/02 Servers
SQL SERVER存储过程用法详解
2022/02/24 SQL Server
Spring Cloud Netflix 套件中的负载均衡组件 Ribbon
2022/04/13 Java/Android