pytorch实现seq2seq时对loss进行mask的方式


Posted in Python onFebruary 18, 2020

如何对loss进行mask

pytorch官方教程中有一个Chatbot教程,就是利用seq2seq和注意力机制实现的,感觉和机器翻译没什么不同啊,如果对话中一句话有下一句,那么就把这一对句子加入模型进行训练。其中在训练阶段,损失函数通常需要进行mask操作,因为一个batch中句子的长度通常是不一样的,一个batch中不足长度的位置需要进行填充(pad)补0,最后生成句子计算loss时需要忽略那些原本是pad的位置的值,即只保留mask中值为1位置的值,忽略值为0位置的值,具体演示如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

PAD_token = 0

首先是pad函数和建立mask矩阵,矩阵的维度应该和目标一致。

def zeroPadding(l, fillvalue=PAD_token):
 # 输入:[[1, 1, 1], [2, 2], [3]]
 # 返回:[(1, 2, 3), (1, 2, 0), (1, 0, 0)] 返回已经是转置后的 [L, B]
 return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def binaryMatrix(l):
 # 将targets里非pad部分标记为1,pad部分标记为0
 m = []
 for i, seq in enumerate(l):
 m.append([])
 for token in seq:
  if token == PAD_token:
  m[i].append(0)
  else:
  m[i].append(1)
 return m

假设现在输入一个batch中有三个句子,我们按照长度从大到小排好序,LSTM或是GRU的输入和输出我们需要利用pack_padded_sequence和pad_packed_sequence进行打包和解包,感觉也是在进行mask操作。

inputs = [[1, 2, 3], [4, 5], [6]] # 输入句,一个batch,需要按照长度从大到小排好序
inputs_lengths = [3, 2, 1]
targets = [[1, 2], [1, 2, 3], [1]] # 目标句,这里的长度是不确定的,mask是针对targets的
inputs_batch = torch.LongTensor(zeroPadding(inputs))
inputs_lengths = torch.LongTensor(inputs_lengths)
targets_batch = torch.LongTensor(zeroPadding(targets))
targets_mask = torch.ByteTensor(binaryMatrix(zeroPadding(targets))) # 注意这里是ByteTensor
print(inputs_batch)
print(targets_batch)
print(targets_mask)

打印后结果如下,可见维度统一变成了[L, B],并且mask和target长得一样。另外,seq2seq模型处理时for循环每次读取一行,预测下一行的值(即[B, L]时的一列预测下一列)。

tensor([[ 1, 4, 6],
 [ 2, 5, 0],
 [ 3, 0, 0]])
tensor([[ 1, 1, 1],
 [ 2, 2, 0],
 [ 0, 3, 0]])
tensor([[ 1, 1, 1],
 [ 1, 1, 0],
 [ 0, 1, 0]], dtype=torch.uint8)

现在假设我们将inputs输入模型后,模型读入sos后预测的第一行为outputs1, 维度为[B, vocab_size],即每个词在词汇表中的概率,模型输出之前需要softmax。

outputs1 = torch.FloatTensor([[0.2, 0.1, 0.7], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]])
print(outputs1)
tensor([[ 0.2000, 0.1000, 0.7000],
 [ 0.3000, 0.6000, 0.1000],
 [ 0.4000, 0.5000, 0.1000]])

先看看两个函数

torch.gather(input, dim, index, out=None)->Tensor

沿着某个轴,按照指定维度采集数据,对于3维数据,相当于进行如下操作:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

比如在这里,在第1维,选第二个元素。

# 收集每行的第2个元素
temp = torch.gather(outputs1, 1, torch.LongTensor([[1], [1], [1]]))
print(temp)
tensor([[ 0.1000],
 [ 0.6000],
 [ 0.5000]])

torch.masked_select(input, mask, out=None)->Tensor

根据mask(ByteTensor)选取对应位置的值,返回一维张量。

例如在这里我们选取temp大于等于0.5的值。

mask = temp.ge(0.5) # 大于等于0.5
print(mask)
print(torch.masked_select(temp, temp.ge(0.5)))
tensor([[ 0],
 [ 1],
 [ 1]], dtype=torch.uint8)
tensor([ 0.6000, 0.5000])

然后我们就可以计算loss了,这里是负对数损失函数,之前模型的输出要进行softmax。

# 计算一个batch内的平均负对数似然损失,即只考虑mask为1的元素
def maskNLLLoss(inp, target, mask):
 nTotal = mask.sum()
 # 收集目标词的概率,并取负对数
 crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
 # 只保留mask中值为1的部分,并求均值
 loss = crossEntropy.masked_select(mask).mean()
 loss = loss.to(DEVICE)
 return loss, nTotal.item()

这里我们计算第一行的平均损失。

# 计算预测的第一行和targets的第一行的loss
maskNLLLoss(outputs1, targets_batch[0], targets_mask[0])

(tensor(1.1689, device='cuda:0'), 3)

最后进行最后把所有行的loss累加起来变为total_loss.backward()进行反向传播就可以了。

以上这篇pytorch实现seq2seq时对loss进行mask的方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python对象体系深入分析
Oct 28 Python
详解Python中的日志模块logging
Jun 19 Python
python使用str & repr转换字符串
Oct 13 Python
python 调用win32pai 操作cmd的方法
May 28 Python
dataframe设置两个条件取值的实例
Apr 12 Python
selenium+python实现自动登录脚本
Apr 22 Python
利用python求积分的实例
Jul 03 Python
django的聚合函数和aggregate、annotate方法使用详解
Jul 23 Python
Python 实用技巧之利用Shell通配符做字符串匹配
Aug 23 Python
利用 Flask 动态展示 Pyecharts 图表数据方法小结
Sep 04 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 Python
python中pandas.read_csv()函数的深入讲解
Mar 29 Python
python多项式拟合之np.polyfit 和 np.polyld详解
Feb 18 #Python
tensorflow 分类损失函数使用小记
Feb 18 #Python
python如何把字符串类型list转换成list
Feb 18 #Python
python计算波峰波谷值的方法(极值点)
Feb 18 #Python
Python表达式的优先级详解
Feb 18 #Python
使用Tkinter制作信息提示框
Feb 18 #Python
Python中import导入不同目录的模块方法详解
Feb 18 #Python
You might like
《一拳超人》埼玉一拳下去,他们存在了800年毫无意义!
2020/03/02 日漫
PHP If Else(elsefi) 语句
2013/04/07 PHP
PHP使用SOAP调用.net的WebService数据
2013/11/12 PHP
一个严格的PHP Session会话超时时间设置方法
2014/06/10 PHP
PHP+jQuery+Ajax实现分页效果 jPaginate插件的应用
2015/10/09 PHP
在WordPress的文章编辑器中设置默认内容的方法
2015/12/29 PHP
用 JavaScript 迁移目录
2006/12/18 Javascript
js直接编辑当前cookie的脚本
2008/09/14 Javascript
jquery实现div阴影效果示例代码
2013/09/16 Javascript
超级好用的jQuery圆角插件 Corner速成
2014/08/31 Javascript
jQuery中after()方法用法实例
2014/12/25 Javascript
JS+CSS实现下拉列表框美化效果(3款)
2015/08/15 Javascript
浅谈Javascript中Object与Function对象
2015/09/26 Javascript
js实现跨域访问的三种方法
2015/12/09 Javascript
js和jquery实现监听键盘事件示例代码
2020/06/24 Javascript
JS组件系列之Bootstrap Icon图标选择组件
2016/01/28 Javascript
AngularJS 中使用Swiper制作滚动图不能滑动的解决方法
2016/11/15 Javascript
Bootstrap table使用方法记录
2017/08/23 Javascript
JS实现table表格内针对某列内容进行即时搜索筛选功能
2018/05/11 Javascript
JS数组实现分类统计实例代码
2018/09/30 Javascript
[03:30]DOTA2完美“圣”典精彩集锦
2016/12/27 DOTA
PyTorch CNN实战之MNIST手写数字识别示例
2018/05/29 Python
Python小程序之在图片上加入数字的代码
2019/11/26 Python
Keras在训练期间可视化训练误差和测试误差实例
2020/06/16 Python
CSS 说明横向进度条最后显示文字的实现代码
2020/11/10 HTML / CSS
德国孕妇装和婴童服装网上商店:bellybutton
2018/04/12 全球购物
编写类String的构造函数、析构函数和赋值函数
2012/05/29 面试题
怎么写好自荐信
2013/10/30 职场文书
电大本科自我鉴定
2014/02/05 职场文书
青年志愿者先进事迹
2014/05/06 职场文书
学习方法演讲稿
2014/05/10 职场文书
物流管理专业求职信
2014/05/29 职场文书
班级课外活动总结
2014/07/09 职场文书
赔偿协议书
2015/01/27 职场文书
Win11安装受阻怎么办? Windows11安装问题与解决方案汇总
2021/11/21 数码科技
windows10 家庭版下FTP服务器搭建教程
2022/08/05 Servers