pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python常用随机数与随机字符串方法实例
Apr 09 Python
在Python中用keys()方法返回字典键的教程
May 21 Python
代码分析Python地图坐标转换
Feb 08 Python
WxPython建立批量录入框窗口
Feb 27 Python
python2.7使用plotly绘制本地散点图和折线图
Apr 02 Python
django-crontab 定时执行任务方法的实现
Sep 06 Python
浅析python redis的连接及相关操作
Nov 07 Python
利用 Python ElementTree 生成 xml的实例
Mar 06 Python
使用IPython或Spyder将省略号表示的内容完整输出
Apr 20 Python
Pytorch转tflite方式
May 25 Python
PyTorch device与cuda.device用法
Apr 03 Python
python自动获取微信公众号最新文章的实现代码
Jul 15 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
php对数组排序代码分享
2014/02/24 PHP
查找php配置文件php.ini所在路径的二种方法
2014/05/26 PHP
php上传图片并压缩的实现方法
2015/12/22 PHP
PHP的Yii框架中YiiBase入口类的扩展写法示例
2016/03/17 PHP
Yii2前后台分离及migrate使用(七)
2016/05/04 PHP
php socket通信简单实现
2016/11/18 PHP
js图片卷帘门导航菜单特效代码分享
2015/09/10 Javascript
jQuery实用技巧必备(上)
2015/11/02 Javascript
基于JS实现的笛卡尔乘积之商品发布
2016/05/13 Javascript
javascript中Number的方法小结
2016/11/21 Javascript
Bootstrap基本组件学习笔记之缩略图(13)
2016/12/08 Javascript
js 获取图像缩放后的实际宽高,位置等信息
2017/03/07 Javascript
JS完成画圆圈的小球
2017/03/07 Javascript
微信小程序之圆形进度条实现思路
2018/02/22 Javascript
vue实现裁切图片同时实现放大、缩小、旋转功能
2018/03/02 Javascript
Vuejs2 + Webpack框架里,模拟下载的实例讲解
2018/09/05 Javascript
详解如何提升JSON.stringify()的性能
2019/06/12 Javascript
Vue.js递归组件实现组织架构树和选人功能
2019/07/04 Javascript
在vue中利用全局路由钩子给url统一添加公共参数的例子
2019/11/01 Javascript
JS数组方法reduce的用法实例分析
2020/03/03 Javascript
vue cli 3.0通用打包配置代码,不分一二级目录
2020/09/02 Javascript
原生js实现移动小球(碰撞检测)
2020/12/17 Javascript
Python中 map()函数的用法详解
2018/07/10 Python
python内存监控工具memory_profiler和guppy的用法详解
2019/07/29 Python
对django 模型 unique together的示例讲解
2019/08/06 Python
python误差棒图errorbar()函数实例解析
2020/02/11 Python
CSS3实战第一波 让我们尽情的圆角吧
2010/08/27 HTML / CSS
HTML5 history新特性pushState、replaceState及两者的区别
2015/12/26 HTML / CSS
比利时香水网上商店:NOTINO
2018/03/28 全球购物
英国IT硬件供应商,定制游戏PC:Mesh Computers
2019/03/28 全球购物
《珍珠泉》教学反思
2014/02/20 职场文书
幼儿园户外活动总结
2014/07/04 职场文书
乡镇个人对照检查材料
2014/08/22 职场文书
2014年统计工作总结
2014/11/21 职场文书
2014年科技工作总结
2014/11/26 职场文书
Python+DeOldify实现老照片上色功能
2022/06/21 Python