pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
读写json中文ASCII乱码问题的解决方法
Nov 05 Python
Python读取指定目录下指定后缀文件并保存为docx
Apr 23 Python
Python实现购物程序思路及代码
Jul 24 Python
python执行使用shell命令方法分享
Nov 08 Python
Python iter()函数用法实例分析
Mar 17 Python
Python基于分析Ajax请求实现抓取今日头条街拍图集功能示例
Jul 19 Python
Pyqt5实现英文学习词典
Jun 24 Python
浅析pandas 数据结构中的DataFrame
Oct 12 Python
Python通过文本和图片生成词云图
May 21 Python
详解Django自定义图片和文件上传路径(upload_to)的2种方式
Dec 01 Python
详细介绍python操作RabbitMq
Apr 12 Python
python和C/C++混合编程之使用ctypes调用 C/C++的dll
Apr 29 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
PHP安装攻略:常见问题解答(二)
2006/10/09 PHP
修改apache配置文件去除thinkphp url中的index.php
2014/01/17 PHP
php实现的PDO异常处理操作分析
2018/12/27 PHP
PHP设计模式(八)装饰器模式Decorator实例详解【结构型】
2020/05/02 PHP
传递参数的标准方法(jQuery.ajax)
2008/11/19 Javascript
javascript学习笔记(五) Array 数组类型介绍
2012/06/19 Javascript
Jquery多选下拉列表插件jquery multiselect功能介绍及使用
2013/05/24 Javascript
document.getElementById获取控件对象为空的解决方法
2013/11/20 Javascript
js sort 二维数组排序的用法小结
2014/01/24 Javascript
Node.js中require的工作原理浅析
2014/06/24 Javascript
JavaScript创建一个object对象并操作对象属性的用法
2015/03/23 Javascript
Vue.js每天必学之内部响应式原理探究
2016/09/07 Javascript
canvas实现简易的圆环进度条效果
2017/02/28 Javascript
ZeroClipboard.js使用一个flash复制多个文本框
2017/06/19 Javascript
JS+Ajax实现百度智能搜索框
2017/08/04 Javascript
Vue.js表单标签中的单选按钮、复选按钮和下拉列表的取值问题
2017/11/22 Javascript
基于jquery trigger函数无法触发a标签的两种解决方法
2018/01/06 jQuery
vue中倒计时组件的实例代码
2018/07/06 Javascript
Vue.js 使用v-cloak后仍显示变量的解决方法
2018/11/19 Javascript
JavaScript实现鼠标移入随机变换颜色
2020/11/24 Javascript
[02:42]DOTA2城市挑战赛收官在即 四强之争风起云涌
2018/06/05 DOTA
python搭建简易服务器分析与实现
2012/12/15 Python
Python计算回文数的方法
2015/03/11 Python
Python 获得13位unix时间戳的方法
2017/10/20 Python
关于Django显示时间你应该知道的一些问题
2017/12/25 Python
Python三种遍历文件目录的方法实例代码
2018/01/19 Python
详解Python中的分组函数groupby和itertools)
2018/07/11 Python
Python使用urllib模块对URL网址中的中文编码与解码实例详解
2020/02/18 Python
Python2 与Python3的版本区别实例分析
2020/03/30 Python
Python 转移文件至云对象存储的方法
2021/02/07 Python
印度尼西亚值得信赖的第一家网店:Bhinneka
2018/07/16 全球购物
迪奥美国官网:Dior美国
2019/12/07 全球购物
继承权公证书
2014/04/09 职场文书
学生会竞选演讲稿怎么写
2014/08/26 职场文书
大学生干部培训心得体会
2016/01/06 职场文书
windows server 2012安装FTP并配置被动模式指定开放端口
2022/06/10 Servers