pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之大话题小函数(1)
Oct 10 Python
更改Python命令行交互提示符的方法
Jan 14 Python
如何高效使用Python字典的方法详解
Aug 31 Python
如何用Python实现简单的Markdown转换器
Jul 16 Python
windows下安装Python虚拟环境virtualenvwrapper-win
Jun 14 Python
简单了解python变量的作用域
Jul 30 Python
Python pip 安装与使用(安装、更新、删除)
Oct 06 Python
opencv3/C++实现视频背景去除建模(BSM)
Dec 11 Python
python matplotlib包图像配色方案分享
Mar 14 Python
python中的selenium安装的步骤(浏览器自动化测试框架)
Mar 17 Python
django 外键创建注意事项说明
May 20 Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
根据中文裁减字符串函数的php代码
2013/12/03 PHP
PHP静态成员变量和非静态成员变量详解
2017/02/14 PHP
PHP 实现人民币小写转换成大写的方法及大小写转换函数
2017/11/17 PHP
几个比较经典常用的jQuery小技巧
2010/03/01 Javascript
基于JavaScript实现 获取鼠标点击位置坐标的方法
2013/04/12 Javascript
探讨JavaScript中声明全局变量三种方式的异同
2013/12/03 Javascript
javascript实现动态模态绑定grid过程代码
2014/09/22 Javascript
jQuery动画出现连续触发、滞后反复执行的解决方法
2015/01/28 Javascript
自定义百度分享的分享按钮
2015/03/18 Javascript
jQuery的deferred对象使用详解
2016/09/25 Javascript
浅析jQuery操作select控件的取值和设值
2016/12/07 Javascript
详解vue事件对象、冒泡、阻止默认行为
2017/03/20 Javascript
基于vue实现swipe轮播组件实例代码
2017/05/24 Javascript
js数组去重的方法总结
2019/01/18 Javascript
ES6小技巧之代替lodash
2019/06/07 Javascript
微信小程序调用后台service教程详解
2020/11/06 Javascript
[05:15]DOTA2英雄梦之声_第16期_灰烬之灵
2014/06/21 DOTA
Python使用scrapy采集数据过程中放回下载过大页面的方法
2015/04/08 Python
使用Python对Excel进行读写操作
2017/03/30 Python
Python中__slots__属性介绍与基本使用方法
2018/09/05 Python
selenium设置proxy、headers的方法(phantomjs、Chrome、Firefox)
2018/11/29 Python
用vue.js组件模拟v-model指令实例方法
2019/07/05 Python
Python 获取命令行参数内容及参数个数的实例
2019/12/20 Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
2020/01/02 Python
阿迪达斯德国官方网站:adidas德国
2017/07/12 全球购物
应届本科生推荐信范文
2013/12/25 职场文书
奶茶店创业计划书
2014/08/14 职场文书
爱岗敬业事迹材料
2014/12/24 职场文书
毕业生自荐材料范文
2014/12/30 职场文书
产品调价通知函
2015/04/20 职场文书
2016学习雷锋精神活动倡议书
2015/04/27 职场文书
公司晚会主持词
2019/04/17 职场文书
2019年销售部季度工作计划3篇
2019/10/09 职场文书
创业计划书之孕婴生活馆
2019/11/11 职场文书
MySQL定时备份数据库(全库备份)的实现
2021/09/25 MySQL
vue ref如何获取子组件属性值
2022/03/31 Vue.js