pytorch中的embedding词向量的使用方法


Posted in Python onAugust 18, 2019

Embedding

词嵌入在 pytorch 中非常简单,只需要调用 torch.nn.Embedding(m, n) 就可以了,m 表示单词的总数目,n 表示词嵌入的维度,其实词嵌入就相当于是一个大矩阵,矩阵的每一行表示一个单词。

emdedding初始化

默认是随机初始化的

import torch
from torch import nn
from torch.autograd import Variable
# 定义词嵌入
embeds = nn.Embedding(2, 5) # 2 个单词,维度 5
# 得到词嵌入矩阵,开始是随机初始化的
torch.manual_seed(1)
embeds.weight
# 输出结果:
Parameter containing:
-0.8923 -0.0583 -0.1955 -0.9656 0.4224
 0.2673 -0.4212 -0.5107 -1.5727 -0.1232
[torch.FloatTensor of size 2x5]

如果从使用已经训练好的词向量,则采用

pretrained_weight = np.array(args.pretrained_weight) # 已有词向量的numpy
self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))

embed的读取

读取一个向量。

注意参数只能是LongTensor型的

# 访问第 50 个词的词向量
embeds = nn.Embedding(100, 10)
embeds(Variable(torch.LongTensor([50])))
# 输出:
Variable containing:
 0.6353 1.0526 1.2452 -1.8745 -0.1069 0.1979 0.4298 -0.3652 -0.7078 0.2642
[torch.FloatTensor of size 1x10]

读取多个向量。

输入为两个维度(batch的大小,每个batch的单词个数),输出则在两个维度上加上词向量的大小。

Input: LongTensor (N, W), N = mini-batch, W = number of indices to extract per mini-batch
Output: (N, W, embedding_dim)

见代码

# an Embedding module containing 10 tensors of size 3
embedding = nn.Embedding(10, 3)
# 每批取两组,每组四个单词
input = Variable(torch.LongTensor([[1,2,4,5],[4,3,2,9]]))
a = embedding(input) # 输出2*4*3
a[0],a[1]

输出为:

(Variable containing:
 -1.2603 0.4337 0.4181
 0.4458 -0.1987 0.4971
 -0.5783 1.3640 0.7588
 0.4956 -0.2379 -0.7678
 [torch.FloatTensor of size 4x3], Variable containing:
 -0.5783 1.3640 0.7588
 -0.5313 -0.3886 -0.6110
 0.4458 -0.1987 0.4971
 -1.3768 1.7323 0.4816
 [torch.FloatTensor of size 4x3])

以上这篇pytorch中的embedding词向量的使用方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的即时标记项目练习笔记
Sep 18 Python
基python实现多线程网页爬虫
Sep 06 Python
Python实现简单登录验证
Apr 13 Python
Python模拟简单电梯调度算法示例
Aug 20 Python
Python字符串的常见操作实例小结
Apr 08 Python
Python+OpenCV采集本地摄像头的视频
Apr 25 Python
scrapy-redis源码分析之发送POST请求详解
May 15 Python
Python线程threading模块用法详解
Feb 26 Python
python+requests接口自动化框架的实现
Aug 31 Python
python绘制分布折线图的示例
Sep 24 Python
python 使用三引号时容易犯的小错误
Oct 21 Python
利用For循环遍历Python字典的三种方法实例
Mar 25 Python
Pytorch加载部分预训练模型的参数实例
Aug 18 #Python
在pytorch中查看可训练参数的例子
Aug 18 #Python
浅析PyTorch中nn.Module的使用
Aug 18 #Python
关于PyTorch 自动求导机制详解
Aug 18 #Python
pytorch神经网络之卷积层与全连接层参数的设置方法
Aug 18 #Python
pytorch numpy list类型之间的相互转换实例
Aug 18 #Python
对Pytorch中nn.ModuleList 和 nn.Sequential详解
Aug 18 #Python
You might like
php中文件上传的安全问题
2006/10/09 PHP
提示Trying to clone an uncloneable object of class Imagic的解决
2011/10/27 PHP
php实现过滤表单提交中html标签的方法
2014/10/17 PHP
javascript 火狐(firefox)不显示本地图片问题解决
2008/07/05 Javascript
js继承 Base类的源码解析
2008/12/30 Javascript
JavaScript Object的extend是一个常用的功能
2009/12/02 Javascript
自定义jQuery选项卡插件实例
2013/03/27 Javascript
JavaScript对内存分配及管理机制详细解析
2013/11/11 Javascript
jquery ajax jsonp跨域调用实例代码
2013/12/11 Javascript
如何让浏览器支持jquery ajax load 前进、后退功能
2014/06/12 Javascript
如何减少浏览器的reflow和repaint
2015/02/26 Javascript
js实现网页收藏功能
2015/12/17 Javascript
原生js获取iframe中dom元素--父子页面相互获取对方dom元素的方法
2016/08/05 Javascript
微信小程序Server端环境配置详解(SSL, Nginx HTTPS,TLS 1.2 升级)
2017/01/12 Javascript
ES5学习教程之Array对象
2017/04/01 Javascript
CSS3+JavaScript实现翻页幻灯片效果
2017/06/28 Javascript
vue2导航根据路由传值,而改变导航内容的实例
2017/11/10 Javascript
在Vue组件中使用 TypeScript的方法
2018/02/28 Javascript
Vue入门之animate过渡动画效果
2018/04/08 Javascript
jQuery+PHP+Ajax实现动态数字统计展示功能
2019/12/25 jQuery
uni-app如何页面传参数的几种方法总结
2020/04/28 Javascript
vue项目页面嵌入代码块vue-prism-editor的实现
2020/10/30 Javascript
Python多线程学习资料
2012/12/19 Python
Python Deque 模块使用详解
2014/07/04 Python
python 2.6.6升级到python 2.7.x版本的方法
2016/10/09 Python
Python模块WSGI使用详解
2018/02/02 Python
numpy中实现二维数组按照某列、某行排序的方法
2018/04/04 Python
python中字符串的操作方法大全
2018/06/03 Python
用python3 返回鼠标位置的实现方法(带界面)
2019/07/05 Python
python飞机大战pygame游戏背景设计详解
2019/12/17 Python
python带参数打包exe及调用方式
2019/12/21 Python
零基础小白多久能学会python
2020/06/22 Python
CSS3中Animation动画属性用法详解
2016/07/04 HTML / CSS
机械专业求职信
2014/05/25 职场文书
助人为乐模范事迹材料
2014/06/02 职场文书
2016中秋节月饼促销广告语
2016/01/28 职场文书