pytorch实现用CNN和LSTM对文本进行分类方式


Posted in Python onJanuary 08, 2020

model.py:

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import torch
from torch import nn
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
 
class TextRNN(nn.Module):
  """文本分类,RNN模型"""
  def __init__(self):
    super(TextRNN, self).__init__()
    # 三个待输入的数据
    self.embedding = nn.Embedding(5000, 64) # 进行词嵌入
    # self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bidirectional=True)
    self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2, bidirectional=True)
    self.f1 = nn.Sequential(nn.Linear(256,128),
                nn.Dropout(0.8),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128,10),
                nn.Softmax())
 
  def forward(self, x):
    x = self.embedding(x)
    x,_ = self.rnn(x)
    x = F.dropout(x,p=0.8)
    x = self.f1(x[:,-1,:])
    return self.f2(x)
 
class TextCNN(nn.Module):
  def __init__(self):
    super(TextCNN, self).__init__()
    self.embedding = nn.Embedding(5000,64)
    self.conv = nn.Conv1d(64,256,5)
    self.f1 = nn.Sequential(nn.Linear(256*596, 128),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128, 10),
                nn.Softmax())
  def forward(self, x):
    x = self.embedding(x)
    x = x.detach().numpy()
    x = np.transpose(x,[0,2,1])
    x = torch.Tensor(x)
    x = Variable(x)
    x = self.conv(x)
    x = x.view(-1,256*596)
    x = self.f1(x)
    return self.f2(x)

train.py:

# coding: utf-8
 
from __future__ import print_function
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import os
 
import numpy as np
 
from model import TextRNN,TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
 
base_dir = 'cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
 
 
def train():
  x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,600)#获取训练数据每个字的id和对应标签的oe-hot形式
  x_val, y_val = process_file(val_dir, word_to_id, cat_to_id,600)
  #使用LSTM或者CNN
  model = TextRNN()
  # model = TextCNN()
  #选择损失函数
  Loss = nn.MultiLabelSoftMarginLoss()
  # Loss = nn.BCELoss()
  # Loss = nn.MSELoss()
  optimizer = optim.Adam(model.parameters(),lr=0.001)
  best_val_acc = 0
  for epoch in range(1000):
    batch_train = batch_iter(x_train, y_train,100)
    for x_batch, y_batch in batch_train:
      x = np.array(x_batch)
      y = np.array(y_batch)
      x = torch.LongTensor(x)
      y = torch.Tensor(y)
      # y = torch.LongTensor(y)
      x = Variable(x)
      y = Variable(y)
      out = model(x)
      loss = Loss(out,y)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      accracy = np.mean((torch.argmax(out,1)==torch.argmax(y,1)).numpy())
    #对模型进行验证
    if (epoch+1)%20 == 0:
      batch_val = batch_iter(x_val, y_val, 100)
      for x_batch, y_batch in batch_train:
        x = np.array(x_batch)
        y = np.array(y_batch)
        x = torch.LongTensor(x)
        y = torch.Tensor(y)
        # y = torch.LongTensor(y)
        x = Variable(x)
        y = Variable(y)
        out = model(x)
        loss = Loss(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accracy = np.mean((torch.argmax(out, 1) == torch.argmax(y, 1)).numpy())
        if accracy > best_val_acc:
          torch.save(model.state_dict(),'model_params.pkl')
          best_val_acc = accracy
        print(accracy)
 
if __name__ == '__main__':
  #获取文本的类别及其对应id的字典
  categories, cat_to_id = read_category()
  #获取训练文本中所有出现过的字及其所对应的id
  words, word_to_id = read_vocab(vocab_dir)
  #获取字数
  vocab_size = len(words)
  train()

test.py:

# coding: utf-8
 
from __future__ import print_function
 
import os
import tensorflow.contrib.keras as kr
import torch
from torch import nn
from cnews_loader import read_category, read_vocab
from model import TextRNN
from torch.autograd import Variable
import numpy as np
try:
  bool(type(unicode))
except NameError:
  unicode = str
 
base_dir = 'cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
 
class TextCNN(nn.Module):
  def __init__(self):
    super(TextCNN, self).__init__()
    self.embedding = nn.Embedding(5000,64)
    self.conv = nn.Conv1d(64,256,5)
    self.f1 = nn.Sequential(nn.Linear(152576, 128),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128, 10),
                nn.Softmax())
  def forward(self, x):
    x = self.embedding(x)
    x = x.detach().numpy()
    x = np.transpose(x,[0,2,1])
    x = torch.Tensor(x)
    x = Variable(x)
    x = self.conv(x)
    x = x.view(-1,152576)
    x = self.f1(x)
    return self.f2(x)
 
class CnnModel:
  def __init__(self):
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextCNN()
    self.model.load_state_dict(torch.load('model_params.pkl'))
 
  def predict(self, message):
    # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
    content = unicode(message)
    data = [self.word_to_id[x] for x in content if x in self.word_to_id]
    data = kr.preprocessing.sequence.pad_sequences([data],600)
    data = torch.LongTensor(data)
    y_pred_cls = self.model(data)
    class_index = torch.argmax(y_pred_cls[0]).item()
    return self.categories[class_index]
 
class RnnModel:
  def __init__(self):
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextRNN()
    self.model.load_state_dict(torch.load('model_rnn_params.pkl'))
 
  def predict(self, message):
    # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
    content = unicode(message)
    data = [self.word_to_id[x] for x in content if x in self.word_to_id]
    data = kr.preprocessing.sequence.pad_sequences([data], 600)
    data = torch.LongTensor(data)
    y_pred_cls = self.model(data)
    class_index = torch.argmax(y_pred_cls[0]).item()
    return self.categories[class_index]
 
 
if __name__ == '__main__':
  model = CnnModel()
  # model = RnnModel()
  test_demo = ['湖人助教力助科比恢复手感 他也是阿泰的精神导师新浪体育讯记者戴高乐报道 上赛季,科比的右手食指遭遇重创,他的投篮手感也因此大受影响。不过很快科比就调整了自己的投篮手型,并通过这一方式让自己的投篮命中率回升。而在这科比背后,有一位特别助教对科比帮助很大,他就是查克·珀森。珀森上赛季担任湖人的特别助教,除了帮助科比调整投篮手型之外,他的另一个重要任务就是担任阿泰的精神导师。来到湖人队之后,阿泰收敛起了暴躁的脾气,成为湖人夺冠路上不可或缺的一员,珀森的“心灵按摩”功不可没。经历了上赛季的成功之后,珀森本赛季被“升职”成为湖人队的全职助教,每场比赛,他都会坐在球场边,帮助禅师杰克逊一起指挥湖人球员在场上拼杀。对于珀森的工作,禅师非常欣赏,“查克非常善于分析问题,”菲尔·杰克逊说,“他总是在寻找问题的答案,同时也在找造成这一问题的原因,这是我们都非常乐于看到的。我会在平时把防守中出现的一些问题交给他,然后他会通过组织球员练习找到解决的办法。他在球员时代曾是一名很好的外线投手,不过现在他与内线球员的配合也相当不错。',
         '弗老大被裁美国媒体看热闹“特权”在中国像蠢蛋弗老大要走了。虽然他只在首钢男篮效力了13天,而且表现毫无亮点,大大地让球迷和俱乐部失望了,但就像中国人常说的“好聚好散”,队友还是友好地与他告别,俱乐部与他和平分手,球迷还请他留下了在北京的最后一次签名。相比之下,弗老大的同胞美国人却没那么“宽容”。他们嘲讽这位NBA前巨星的英雄迟暮,批评他在CBA的业余表现,还惊讶于中国人的“大方”。今天,北京首钢俱乐部将与弗朗西斯继续商讨解约一事。从昨日的进展来看,双方可以做到“买卖不成人意在”,但回到美国后,恐怕等待弗朗西斯的就没有这么轻松的环境了。进展@北京昨日与队友告别 最后一次为球迷签名弗朗西斯在13天里为首钢队打了4场比赛,3场的得分为0,只有一场得了2分。昨天是他来到北京的第14天,虽然他与首钢还未正式解约,但双方都明白“缘分已尽”。下午,弗朗西斯来到首钢俱乐部与队友们告别。弗朗西斯走到队友身边,依次与他们握手拥抱。“你们都对我很好,安排的条件也很好,我很喜欢这支球队,想融入你们,但我现在真的很不适应。希望你们']
  for i in test_demo:
    print(i,":",model.predict(i))

cnews_loader.py:

# coding: utf-8
 
import sys
from collections import Counter
 
import numpy as np
import tensorflow.contrib.keras as kr
 
if sys.version_info[0] > 2:
  is_py3 = True
else:
  reload(sys)
  sys.setdefaultencoding("utf-8")
  is_py3 = False
 
 
def native_word(word, encoding='utf-8'):
  """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码"""
  if not is_py3:
    return word.encode(encoding)
  else:
    return word
 
 
def native_content(content):
  if not is_py3:
    return content.decode('utf-8')
  else:
    return content
 
 
def open_file(filename, mode='r'):
  """
  常用文件操作,可在python2和python3间切换.
  mode: 'r' or 'w' for read or write
  """
  if is_py3:
    return open(filename, mode, encoding='utf-8', errors='ignore')
  else:
    return open(filename, mode)
 
 
def read_file(filename):
  """读取文件数据"""
  contents, labels = [], []
  with open_file(filename) as f:
    for line in f:
      try:
        label, content = line.strip().split('\t')
        if content:
          contents.append(list(native_content(content)))
          labels.append(native_content(label))
      except:
        pass
  return contents, labels
 
 
def build_vocab(train_dir, vocab_dir, vocab_size=5000):
  """根据训练集构建词汇表,存储"""
  data_train, _ = read_file(train_dir)
 
  all_data = []
  for content in data_train:
    all_data.extend(content)
 
  counter = Counter(all_data)
  count_pairs = counter.most_common(vocab_size - 1)
  words, _ = list(zip(*count_pairs))
  # 添加一个 <PAD> 来将所有文本pad为同一长度
  words = ['<PAD>'] + list(words)
  open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')
 
 
def read_vocab(vocab_dir):
  """读取词汇表"""
  # words = open_file(vocab_dir).read().strip().split('\n')
  with open_file(vocab_dir) as fp:
    # 如果是py2 则每个值都转化为unicode
    words = [native_content(_.strip()) for _ in fp.readlines()]
  word_to_id = dict(zip(words, range(len(words))))
  return words, word_to_id
 
 
def read_category():
  """读取分类目录,固定"""
  categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
 
  categories = [native_content(x) for x in categories]
 
  cat_to_id = dict(zip(categories, range(len(categories))))
 
  return categories, cat_to_id
 
 
def to_words(content, words):
  """将id表示的内容转换为文字"""
  return ''.join(words[x] for x in content)
 
 
def process_file(filename, word_to_id, cat_to_id, max_length=600):
  """将文件转换为id表示"""
  contents, labels = read_file(filename)#读取训练数据的每一句话及其所对应的类别
  data_id, label_id = [], []
  for i in range(len(contents)):
    data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])#将每句话id化
    label_id.append(cat_to_id[labels[i]])#每句话对应的类别的id
  #
  # # 使用keras提供的pad_sequences来将文本pad为固定长度
  x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
  y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 将标签转换为one-hot表示
  #
  return x_pad, y_pad
 
 
def batch_iter(x, y, batch_size=64):
  """生成批次数据"""
  data_len = len(x)
  num_batch = int((data_len - 1) / batch_size) + 1
 
  indices = np.random.permutation(np.arange(data_len))
  x_shuffle = x[indices]
  y_shuffle = y[indices]
 
  for i in range(num_batch):
    start_id = i * batch_size
    end_id = min((i + 1) * batch_size, data_len)
    yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

以上这篇pytorch实现用CNN和LSTM对文本进行分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解决字典中的值是列表问题的方法
Mar 04 Python
python正常时间和unix时间戳相互转换的方法
Apr 23 Python
ubuntu系统下 python链接mysql数据库的方法
Jan 09 Python
python模块之time模块(实例讲解)
Sep 13 Python
浅谈numpy数组的几种排序方式
Dec 15 Python
Puppeteer使用示例详解
Jun 20 Python
详解如何用python实现一个简单下载器的服务端和客户端
Oct 28 Python
使用Python生成200个激活码的实现方法
Nov 22 Python
keras获得model中某一层的某一个Tensor的输出维度教程
Jan 24 Python
Python os模块常用方法和属性总结
Feb 20 Python
详解python常用命令行选项与环境变量
Feb 20 Python
Python基本知识点总结
Apr 07 Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
You might like
PHP防注入安全代码
2008/04/09 PHP
PHP实现批量修改文件后缀名的方法
2015/07/30 PHP
thinkPHP框架通过Redis实现增删改查操作的方法详解
2019/05/13 PHP
ThinkPHP5.1的权限控制怎么写?分享一个AUTH权限控制
2021/03/09 PHP
javascript之更有效率的字符串替换
2008/08/02 Javascript
通过event对象的fromElement属性解决热区设置主实体的一个bug
2008/12/22 Javascript
jQuery+CSS 实现的超Sexy下拉菜单
2010/01/17 Javascript
jQuery判断元素是否是隐藏的代码
2011/04/24 Javascript
自己动手制作jquery插件之自动添加删除行功能介绍
2011/10/14 Javascript
JavaScript控制图片加载完成后调用回调函数的方法
2015/03/20 Javascript
nodejs实现HTTPS发起POST请求
2015/04/23 NodeJs
全面解析Bootstrap排版使用方法(标题)
2015/11/30 Javascript
bootstrap treeview 树形菜单带复选框及级联选择功能
2018/06/08 Javascript
vue-cli配置环境变量的方法
2018/07/09 Javascript
原生JS检测CSS3动画是否结束的方法详解
2019/01/27 Javascript
微信小程序用户拒绝授权的处理方法详解
2019/09/20 Javascript
微信小程序实现多选框功能的实例代码
2020/06/24 Javascript
[01:03:47]VP vs NewBee Supermajor 胜者组 BO3 第一场 6.5
2018/06/06 DOTA
Python计算程序运行时间的方法
2014/12/13 Python
Python装饰器使用示例及实际应用例子
2015/03/06 Python
python 如何快速找出两个电子表中数据的差异
2017/05/26 Python
python3+PyQt5实现文档打印功能
2018/04/24 Python
对Django中内置的User模型实例详解
2019/08/16 Python
Python 依赖库太多了该如何管理
2019/11/08 Python
解决django后台管理界面添加中文内容乱码问题
2019/11/15 Python
Windows+Anaconda3+PyTorch+PyCharm的安装教程图文详解
2020/04/03 Python
来自世界各地的优质葡萄酒:VineShop24
2018/07/09 全球购物
什么是数组名
2012/05/10 面试题
绩效工资实施方案
2014/03/15 职场文书
医院安全生产月活动总结
2014/07/05 职场文书
计生工作先进事迹
2014/08/15 职场文书
科技工作者先进事迹
2014/08/16 职场文书
党员民主生活会个人整改措施材料
2014/09/16 职场文书
乡镇党的群众路线教育实践活动总结报告
2014/10/30 职场文书
病危通知单
2015/04/17 职场文书
Java面试题冲刺第十六天--消息队列
2021/08/07 面试题