pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
开始着手第一个Django项目
Jul 15 Python
Python中利用Scipy包的SIFT方法进行图片识别的实例教程
Jun 03 Python
Python实现通过文件路径获取文件hash值的方法
Apr 29 Python
python2.7安装图文教程
Mar 13 Python
Python利用splinter实现浏览器自动化操作方法
May 11 Python
详解从Django Rest Framework响应中删除空字段
Jan 11 Python
基于python实现KNN分类算法
Apr 23 Python
python sklearn库实现简单逻辑回归的实例代码
Jul 01 Python
Python3 获取文件属性的方式(时间、大小等)
Mar 12 Python
python中如何进行连乘计算
May 28 Python
利用python控制Autocad:pyautocad方式
Jun 01 Python
python Protobuf定义消息类型知识点讲解
Mar 02 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
Mysql的常用命令
2006/10/09 PHP
自己前几天写的无限分类类
2007/02/14 PHP
php5.3 goto函数介绍和示例
2014/03/21 PHP
PHP 面向对象程序设计(oop)学习笔记 (二) - 静态变量的属性和方法及延迟绑定
2014/06/12 PHP
yii2简单使用less代替css示例
2017/03/10 PHP
在JavaScript中遭遇级联表达式陷阱
2007/03/08 Javascript
javascript 字符串连接的性能问题(多浏览器)
2008/11/18 Javascript
关于javascript 回调函数中变量作用域的讨论
2009/09/11 Javascript
javascript面向对象编程(一) 实例代码
2010/06/25 Javascript
javaScript 利用闭包模拟对象的私有属性
2011/12/29 Javascript
jQuery focus和blur事件的应用详解
2014/01/26 Javascript
jQuery中fadeOut()方法用法实例
2014/12/24 Javascript
深入解析JavaScript中的arguments对象
2016/06/12 Javascript
jQuery编写网页版2048小游戏
2017/01/06 Javascript
Vue-cli3.X使用px2 rem遇到的问题及解决方法
2019/08/08 Javascript
基于ajax实现上传图片代码示例解析
2020/12/03 Javascript
详解微信小程序轨迹回放实现及遇到的坑
2021/02/02 Javascript
python使用正则表达式检测密码强度源码分享
2014/06/11 Python
python with statement 进行文件操作指南
2014/08/22 Python
python3简单实现微信爬虫
2015/04/09 Python
python实现将内容分行输出
2015/11/05 Python
tensorflow 加载部分变量的实例讲解
2018/07/27 Python
详解Python locals()的陷阱
2019/03/26 Python
Python学习笔记之自定义函数用法详解
2019/06/08 Python
python中字符串数组逆序排列方法总结
2019/06/23 Python
Python基于callable函数检测对象是否可被调用
2020/10/16 Python
python 实现客户端与服务端的通信
2020/12/23 Python
极简的HTML5模版
2015/07/09 HTML / CSS
如果NULL定义成#define NULL((char *)0)难道不就可以向函数传入不加转换的NULL了吗
2012/02/15 面试题
大专生自我鉴定范文
2013/10/01 职场文书
个人安全承诺书
2014/05/22 职场文书
关于读书的活动方案
2014/08/14 职场文书
纪念九一八事变演讲稿:勿忘国耻
2014/09/14 职场文书
工作计划范文之财务管理
2019/08/09 职场文书
MongoDB数据库部署环境准备及使用介绍
2022/03/21 MongoDB
Win11电脑显示本地时间与服务器时间不一致怎么解决?
2022/04/05 数码科技