pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 查找文件夹下所有文件 实现代码
Jul 01 Python
Flask框架的学习指南之制作简单blog系统
Nov 20 Python
Python基于最小二乘法实现曲线拟合示例
Jun 14 Python
python 实现批量xls文件转csv文件的方法
Oct 23 Python
Python脚本利用adb进行手机控制的方法
Jul 08 Python
python实现从wind导入数据
Dec 03 Python
Django Admin设置应用程序及模型顺序方法详解
Apr 01 Python
打印tensorflow恢复模型中所有变量与操作节点方式
May 26 Python
Python生成器传参数及返回值原理解析
Jul 22 Python
python多线程和多进程关系详解
Dec 14 Python
python中的unittest框架实例详解
Feb 05 Python
matplotlib之属性组合包(cycler)的使用
Feb 24 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
PHP数据过滤的方法
2013/10/30 PHP
ASP中进行HTML数据及JS数据编码函数
2009/11/11 Javascript
jquery中ajax学习笔记3
2011/10/16 Javascript
网页编辑器ckeditor和ckfinder配置步骤分享
2012/05/24 Javascript
给Flash加一个超链接(推荐使用透明层)兼容主流浏览器
2013/06/09 Javascript
详解JavaScript中setSeconds()方法的使用
2015/06/11 Javascript
JavaScript函数内部属性和函数方法实例详解
2016/03/17 Javascript
jQuery弹出层后禁用底部滚动条(移动端关闭回到原位置)
2016/08/29 Javascript
Angular5中调用第三方库及jQuery的添加的方法
2018/06/07 jQuery
echarts同一页面中四个图表切换的js数据交互方法示例
2018/07/03 Javascript
vue+element-ui动态生成多级表头的方法
2018/08/28 Javascript
vue.js实现图书管理功能
2019/09/24 Javascript
微信小程序停止其他视频播放当前视频的实例代码
2019/12/25 Javascript
VSCode搭建React Native环境
2020/05/07 Javascript
实例讲解React 组件
2020/07/07 Javascript
[01:09]2014DOTA2国际邀请赛 TI4西雅图DOTA2 中国美女coser加油助威
2014/07/20 DOTA
python实现多线程抓取知乎用户
2016/12/12 Python
Python MD5加密实例详解
2017/08/02 Python
K-近邻算法的python实现代码分享
2017/12/09 Python
python执行系统命令后获取返回值的几种方式集合
2018/05/12 Python
Python中logging.NullHandler 的使用教程
2018/11/29 Python
Python argparse模块使用方法解析
2020/02/20 Python
Matlab中plot基本用法的具体使用
2020/07/17 Python
Python logging模块进行封装实现原理解析
2020/08/07 Python
python中numpy数组与list相互转换实例方法
2021/01/29 Python
草莓网英国官网:Strawberrynet UK
2017/02/12 全球购物
英国银首饰公司:e&e Jewellery
2021/02/11 全球购物
学生感冒英文请假条
2014/02/04 职场文书
技校毕业生自荐信范文
2014/03/07 职场文书
羽毛球社团活动总结
2014/06/27 职场文书
稽核岗位职责
2015/02/10 职场文书
最美乡村教师观后感
2015/06/11 职场文书
年终工作总结范文
2019/06/20 职场文书
八年级作文之感恩
2019/11/22 职场文书
在Django中使用MQTT的方法
2021/05/10 Python
解决tk mapper 通用mapper的bug问题
2021/06/16 Java/Android