pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于python的汉字转GBK码实现代码
Feb 19 Python
Django静态资源URL STATIC_ROOT的配置方法
Nov 08 Python
Django中处理出错页面的方法
Jul 15 Python
Python利用Beautiful Soup模块修改内容方法示例
Mar 27 Python
python实现雪花飘落效果实例讲解
Jun 18 Python
Python 元组操作总结
Sep 18 Python
Python函数参数类型及排序原理总结
Dec 19 Python
python itsdangerous模块的具体使用方法
Feb 17 Python
Python在终端通过pip安装好包以后在Pycharm中依然无法使用的问题(三种解决方案)
Mar 10 Python
Python捕获异常堆栈信息的几种方法(小结)
May 18 Python
python 如何实现遗传算法
Sep 22 Python
Python 循环读取数据内存不足的解决方案
May 25 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
《PHP编程最快明白》第八讲:php启发和小结
2010/11/01 PHP
PHP中MVC模式的模板引擎开发经验分享
2011/03/23 PHP
解析php入库和出库
2013/06/25 PHP
thinkphp表单上传文件并将文件路径保存到数据库中
2016/07/28 PHP
php mysql PDO 查询操作的实例详解
2017/09/23 PHP
键盘控制事件应用教程大全
2006/11/24 Javascript
6款新颖的jQuery和CSS3进度条插件推荐
2013/03/05 Javascript
javascript学习笔记(一)基础知识
2014/09/30 Javascript
jQuery zclip插件实现跨浏览器复制功能
2015/11/02 Javascript
微信小程序 实战实例开发流程详细介绍
2017/01/05 Javascript
angular.js + require.js构建模块化单页面应用的方法步骤
2017/07/19 Javascript
基于require.js的使用(实例讲解)
2017/09/07 Javascript
实例解析ES6 Proxy使用场景介绍
2018/01/08 Javascript
JavaScript获取用户所在城市及地理位置
2018/04/21 Javascript
node.js基于socket.io快速实现一个实时通讯应用
2019/04/23 Javascript
vue实现移动端拖动排序
2020/08/21 Javascript
vue-cli 3如何使用vue-bootstrap-datetimepicker日期插件
2021/02/20 Vue.js
python使用matplotlib绘图时图例显示问题的解决
2017/04/27 Python
Python爬取视频(其实是一篇福利)过程解析
2019/08/01 Python
Python环境管理virtualenv&virtualenvwrapper的配置详解
2020/07/01 Python
使用Python项目生成所有依赖包的清单方式
2020/07/13 Python
python unichr函数知识点总结
2020/12/16 Python
关于PySnooper 永远不要使用print进行调试的问题
2021/03/04 Python
CSS3 倾斜的网页图片库实例教程
2009/11/14 HTML / CSS
今天学到的CSS最新技术(与图片背景相关)
2012/12/24 HTML / CSS
CSS3 display知识详解
2015/11/25 HTML / CSS
HTML5 Canvas中绘制矩形实例
2015/01/01 HTML / CSS
英国最大的邮寄种子和植物公司:Thompson & Morgan
2017/09/21 全球购物
世界上最好的足球商店:Unisport
2019/03/02 全球购物
大学毕业生通用自我评价
2014/01/05 职场文书
消防工作实施方案
2014/06/09 职场文书
卫生标语大全
2014/06/21 职场文书
建筑节能汇报材料
2014/08/22 职场文书
企业整改报告范文
2014/11/08 职场文书
高中体育课教学反思
2016/02/16 职场文书
Python图像处理库PIL详细使用说明
2022/04/06 Python