A POS tagging example with PyTorch + LSTM


Posted in Python on January 14, 2020

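After a few days of study, I finally have a rough idea of how to use PyTorch.

The code below is copied directly from the official documentation.

Later I will try implementing other NLP tasks on my own.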

# Author: Robert Guthrie

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)


lstm = nn.LSTM(3, 3)  # Input dim is 3, output dim is 3
# Plain tensors suffice here; autograd.Variable has been deprecated since PyTorch 0.4.
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5

# initialize the hidden state.
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
for i in inputs:
  # Step through the sequence one element at a time.
  # after each step, hidden contains the hidden state.
  out, hidden = lstm(i.view(1, 1, -1), hidden)

# alternatively, we can do the entire sequence all at once.
# the first value returned by LSTM is all of the hidden states throughout
# the sequence. the second is just the most recent hidden state
# (compare the last slice of "out" with "hidden" below, they are the same)
# The reason for this is that:
# "out" will give you access to all hidden states in the sequence
# "hidden" will allow you to continue the sequence and backpropagate,
# by passing it as an argument to the lstm at a later time
# Add the extra 2nd dimension
inputs = torch.cat(inputs).view(len(inputs), 1, -1)
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))  # clean out hidden state
out, hidden = lstm(inputs, hidden)
#print(out)
#print(hidden)
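
The comments above say that the last slice of "out" equals the most recent hidden state, and that feeding the sequence step by step or all at once is equivalent. This can be checked directly. The following is a minimal sketch, assuming a recent PyTorch version where plain tensors can be used without autograd.Variable; the names seq, h0, c0, h_n, c_n and stepwise are chosen here for illustration and do not appear in the original code.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)

lstm = nn.LSTM(3, 3)          # input dim 3, hidden dim 3
seq = torch.randn(5, 1, 3)    # (seq_len, batch, input_dim)
h0 = torch.randn(1, 1, 3)     # (num_layers, batch, hidden_dim)
c0 = torch.randn(1, 1, 3)

# Feed the whole sequence at once.
out, (h_n, c_n) = lstm(seq, (h0, c0))

# Feed one element at a time, carrying the state forward.
state = (h0, c0)
step_outs = []
for t in range(seq.size(0)):
    o, state = lstm(seq[t].view(1, 1, -1), state)
    step_outs.append(o)
stepwise = torch.cat(step_outs)

# For a single-layer, unidirectional LSTM the last output slice
# is the final hidden state.
last_matches_hidden = torch.allclose(out[-1], h_n[0])
stepwise_matches_batch = torch.allclose(out, stepwise, atol=1e-5)
print(out.shape, h_n.shape)
print(last_matches_hidden, stepwise_matches_batch)
```

Both checks compare with a tolerance because the two code paths may differ by floating-point rounding.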

# Prepare the data
def prepare_sequence(seq, to_ix):
  # Map each token to its integer index and return a LongTensor
  idxs = [to_ix[w] for w in seq]
  return torch.tensor(idxs, dtype=torch.long)

training_data = [
  ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
  ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]
word_to_ix = {}
for sent, tags in training_data:
  for word in sent:
    if word not in word_to_ix:
      word_to_ix[word] = len(word_to_ix)
print(word_to_ix)
tag_to_ix = {"DET": 0, "NN": 1, "V": 2}

# These will usually be more like 32 or 64 dimensional.
# We will keep them small, so we can see how the weights change as we train.
EMBEDDING_DIM = 6
HIDDEN_DIM = 6

# Inherits from nn.Module
class LSTMTagger(nn.Module):

  def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
    super(LSTMTagger, self).__init__()
    self.hidden_dim = hidden_dim

    # An embedding matrix of shape (vocab_size, embedding_dim)
    self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

    # Pass in the two dimensions (input size and hidden size)
    # The LSTM takes word embeddings as inputs, and outputs hidden states
    # with dimensionality hidden_dim.
    self.lstm = nn.LSTM(embedding_dim, hidden_dim)

    # The linear layer that maps from hidden state space to tag space
    self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
    self.hidden = self.init_hidden()

  def init_hidden(self):
    # Before we've done anything, we don't have any hidden state.
    # Refer to the PyTorch documentation to see exactly
    # why they have this dimensionality.
    # The axes semantics are (num_layers, minibatch_size, hidden_dim)
    return (torch.zeros(1, 1, self.hidden_dim),
            torch.zeros(1, 1, self.hidden_dim))

  def forward(self, sentence):
    embeds = self.word_embeddings(sentence)
    lstm_out, self.hidden = self.lstm(embeds.view(len(sentence), 1, -1), self.hidden)
    tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
    # Normalize over the tag dimension (one row per word)
    tag_scores = F.log_softmax(tag_space, dim=1)
    return tag_scores

# Embedding dim, hidden dim, vocabulary size, number of tags
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))

# torch.optim provides various optimization algorithms
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# See what the scores are before training
# Note that element i,j of the output is the score for tag j for word i.
inputs = prepare_sequence(training_data[0][0], word_to_ix)
tag_scores = model(inputs)
print(tag_scores)

for epoch in range(300): # again, normally you would NOT do 300 epochs, it is toy data
  for sentence, tags in training_data:
    # Step 1. Remember that PyTorch accumulates gradients.
    # We need to clear them out before each instance.
    model.zero_grad()

    # Also, we need to clear out the hidden state of the LSTM,
    # detaching it from its history on the last instance.
    model.hidden = model.init_hidden()

    # Step 2. Get our inputs ready for the network, that is, turn them into
    # tensors of word indices.
    sentence_in = prepare_sequence(sentence, word_to_ix)
    targets = prepare_sequence(tags, tag_to_ix)

    # Step 3. Run our forward pass.
    tag_scores = model(sentence_in)

    # Step 4. Compute the loss, gradients, and update the parameters by
    # calling optimizer.step()
    loss = loss_function(tag_scores, targets)
    loss.backward()
    optimizer.step()
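
Step 1 in the loop above relies on the fact that PyTorch adds new gradients to whatever is already stored in .grad. A minimal sketch of that behaviour on a single tensor follows; the tensor w and the toy function are assumptions made for illustration and are not part of the tagger.

```python
import torch

w = torch.ones(2, requires_grad=True)

# First backward pass: d(sum(2 * w)) / dw = [2, 2]
(2 * w).sum().backward()
first = w.grad.clone()

# Second backward pass without clearing: gradients add up to [4, 4]
(2 * w).sum().backward()
accumulated = w.grad.clone()

# Reset the stored gradient before the next pass; model.zero_grad()
# plays this role for all parameters of a module.
w.grad.zero_()
(2 * w).sum().backward()
after_clear = w.grad.clone()

print(first, accumulated, after_clear)
```

Without the reset, each update would use the sum of the gradients from all previous sentences rather than the gradient of the current one.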

# See what the scores are after training
inputs = prepare_sequence(training_data[0][0], word_to_ix)
tag_scores = model(inputs)
# The sentence is "The dog ate the apple". Element i,j is the score for tag j
# for word i. The predicted tag is the maximum scoring tag.
# Here, the predicted sequence should be 0 1 2 0 1,
# since 0 is the index of the maximum value in the first row,
# 1 is the index of the maximum value in the second row, etc.
# That is DET NN V DET NN, the correct sequence!
print(tag_scores)
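
Reading the predicted tags off the printed score matrix by eye gets tedious, so the row-wise maximum can be taken with argmax and mapped back to tag names. The sketch below assumes a hand-written score matrix (example_scores) in place of the trained model's output, and the reverse lookup ix_to_tag is a helper introduced here; neither is in the original code.

```python
import torch

tag_to_ix = {"DET": 0, "NN": 1, "V": 2}
ix_to_tag = {ix: tag for tag, ix in tag_to_ix.items()}

# Hypothetical log-probabilities for a 5-word sentence
# (rows: words, columns: tags), not real model output.
example_scores = torch.log(torch.tensor([
    [0.90, 0.05, 0.05],
    [0.10, 0.80, 0.10],
    [0.05, 0.15, 0.80],
    [0.85, 0.10, 0.05],
    [0.10, 0.85, 0.05],
]))

# Index of the highest-scoring tag for each word
predicted_ix = example_scores.argmax(dim=1)
predicted_tags = [ix_to_tag[i] for i in predicted_ix.tolist()]
print(predicted_tags)  # ['DET', 'NN', 'V', 'DET', 'NN']
```

With the trained model, the same two lines can be applied to tag_scores instead of example_scores.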
