pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python压缩文件夹内所有文件为zip文件的方法
Jun 20 Python
离线安装Pyecharts的步骤以及依赖包流程
Apr 23 Python
用python结合jieba和wordcloud实现词云效果
Sep 05 Python
Python打印“菱形”星号代码方法
Feb 05 Python
使用Flask-Cache缓存实现给Flask提速的方法详解
Jun 11 Python
selenium 安装与chromedriver安装的方法步骤
Jun 12 Python
python里dict变成list实例方法
Jun 26 Python
Python基于OpenCV实现人脸检测并保存
Jul 23 Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 Python
Python导入模块包原理及相关注意事项
Mar 25 Python
使用Python合成图片的实现代码(图片添加个性化文本,图片上叠加其他图片)
Apr 30 Python
PyCharm MySQL可视化Database配置过程图解
Jun 09 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
php解决抢购秒杀抽奖等大流量并发入库导致的库存负数的问题
2014/06/19 PHP
php魔术函数__call()用法实例分析
2015/02/13 PHP
浅析Laravel5中队列的配置及使用
2016/08/04 PHP
php curl操作API接口类完整示例
2019/05/21 PHP
laravel实现图片上传预览,及编辑时可更换图片,并实时变化的例子
2019/11/14 PHP
PHP const定义常量及global定义全局常量实例解析
2020/05/28 PHP
Javascript与vbscript数据共享
2007/01/09 Javascript
jQuery html()等方法介绍
2009/11/18 Javascript
Javascript基础知识(一)核心基础语法与事件模型
2014/09/29 Javascript
最流行的Node.js精简型和全栈型开发框架介绍
2015/02/26 Javascript
Adapter适配器模式在JavaScript设计模式编程中的运用分析
2016/05/18 Javascript
javaScript手机号码校验工具类PhoneUtils详解
2017/12/08 Javascript
JavaScript函数、闭包、原型、面向对象学习笔记
2018/09/06 Javascript
如何使用electron-builder及electron-updater给项目配置自动更新
2018/12/24 Javascript
JSON.stringify()方法讲解
2019/01/31 Javascript
详解VUE前端按钮权限控制
2019/04/26 Javascript
新手快速入门微信小程序组件库 iView Weapp
2019/06/24 Javascript
Python list操作用法总结
2015/11/10 Python
Python中方法链的使用方法
2016/02/23 Python
Pycharm debug调试时带参数过程解析
2020/02/03 Python
python GUI库图形界面开发之PyQt5滚动条控件QScrollBar详细使用方法与实例
2020/03/06 Python
python应用Axes3D绘图(批量梯度下降算法)
2020/03/25 Python
在echarts中图例legend和坐标系grid实现左右布局实例
2020/05/16 Python
新手常见Python错误及异常解决处理方案
2020/06/18 Python
解决Pytorch自定义层出现多Variable共享内存错误问题
2020/06/28 Python
Pycharm打开已有项目配置python环境的方法
2020/07/03 Python
scrapy-redis分布式爬虫的搭建过程(理论篇)
2020/09/29 Python
大课间活动制度
2014/01/18 职场文书
教师远程培训感言
2014/03/06 职场文书
在校大学生的职业生涯规划书
2014/03/14 职场文书
道路建设实施方案
2014/03/18 职场文书
个人先进材料范文
2014/12/30 职场文书
运动会闭幕词
2015/01/28 职场文书
公务员政审个人总结
2015/02/12 职场文书
关于感恩的歌曲整理(8首)
2019/08/14 职场文书
MySQL分库分表与分区的入门指南
2021/04/22 MySQL