pytorch实现用CNN和LSTM对文本进行分类方式


Posted in Python onJanuary 08, 2020

model.py:

#!/usr/bin/python
# -*- coding: utf-8 -*-
 
import torch
from torch import nn
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
 
class TextRNN(nn.Module):
  """文本分类,RNN模型"""
  def __init__(self):
    super(TextRNN, self).__init__()
    # 三个待输入的数据
    self.embedding = nn.Embedding(5000, 64) # 进行词嵌入
    # self.rnn = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bidirectional=True)
    self.rnn = nn.GRU(input_size=64, hidden_size=128, num_layers=2, bidirectional=True)
    self.f1 = nn.Sequential(nn.Linear(256,128),
                nn.Dropout(0.8),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128,10),
                nn.Softmax())
 
  def forward(self, x):
    x = self.embedding(x)
    x,_ = self.rnn(x)
    x = F.dropout(x,p=0.8)
    x = self.f1(x[:,-1,:])
    return self.f2(x)
 
class TextCNN(nn.Module):
  def __init__(self):
    super(TextCNN, self).__init__()
    self.embedding = nn.Embedding(5000,64)
    self.conv = nn.Conv1d(64,256,5)
    self.f1 = nn.Sequential(nn.Linear(256*596, 128),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128, 10),
                nn.Softmax())
  def forward(self, x):
    x = self.embedding(x)
    x = x.detach().numpy()
    x = np.transpose(x,[0,2,1])
    x = torch.Tensor(x)
    x = Variable(x)
    x = self.conv(x)
    x = x.view(-1,256*596)
    x = self.f1(x)
    return self.f2(x)

train.py:

# coding: utf-8
 
from __future__ import print_function
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
import os
 
import numpy as np
 
from model import TextRNN,TextCNN
from cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocab
 
base_dir = 'cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
 
 
def train():
  x_train, y_train = process_file(train_dir, word_to_id, cat_to_id,600)#获取训练数据每个字的id和对应标签的oe-hot形式
  x_val, y_val = process_file(val_dir, word_to_id, cat_to_id,600)
  #使用LSTM或者CNN
  model = TextRNN()
  # model = TextCNN()
  #选择损失函数
  Loss = nn.MultiLabelSoftMarginLoss()
  # Loss = nn.BCELoss()
  # Loss = nn.MSELoss()
  optimizer = optim.Adam(model.parameters(),lr=0.001)
  best_val_acc = 0
  for epoch in range(1000):
    batch_train = batch_iter(x_train, y_train,100)
    for x_batch, y_batch in batch_train:
      x = np.array(x_batch)
      y = np.array(y_batch)
      x = torch.LongTensor(x)
      y = torch.Tensor(y)
      # y = torch.LongTensor(y)
      x = Variable(x)
      y = Variable(y)
      out = model(x)
      loss = Loss(out,y)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      accracy = np.mean((torch.argmax(out,1)==torch.argmax(y,1)).numpy())
    #对模型进行验证
    if (epoch+1)%20 == 0:
      batch_val = batch_iter(x_val, y_val, 100)
      for x_batch, y_batch in batch_train:
        x = np.array(x_batch)
        y = np.array(y_batch)
        x = torch.LongTensor(x)
        y = torch.Tensor(y)
        # y = torch.LongTensor(y)
        x = Variable(x)
        y = Variable(y)
        out = model(x)
        loss = Loss(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        accracy = np.mean((torch.argmax(out, 1) == torch.argmax(y, 1)).numpy())
        if accracy > best_val_acc:
          torch.save(model.state_dict(),'model_params.pkl')
          best_val_acc = accracy
        print(accracy)
 
if __name__ == '__main__':
  #获取文本的类别及其对应id的字典
  categories, cat_to_id = read_category()
  #获取训练文本中所有出现过的字及其所对应的id
  words, word_to_id = read_vocab(vocab_dir)
  #获取字数
  vocab_size = len(words)
  train()

test.py:

# coding: utf-8
 
from __future__ import print_function
 
import os
import tensorflow.contrib.keras as kr
import torch
from torch import nn
from cnews_loader import read_category, read_vocab
from model import TextRNN
from torch.autograd import Variable
import numpy as np
try:
  bool(type(unicode))
except NameError:
  unicode = str
 
base_dir = 'cnews'
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
 
class TextCNN(nn.Module):
  def __init__(self):
    super(TextCNN, self).__init__()
    self.embedding = nn.Embedding(5000,64)
    self.conv = nn.Conv1d(64,256,5)
    self.f1 = nn.Sequential(nn.Linear(152576, 128),
                nn.ReLU())
    self.f2 = nn.Sequential(nn.Linear(128, 10),
                nn.Softmax())
  def forward(self, x):
    x = self.embedding(x)
    x = x.detach().numpy()
    x = np.transpose(x,[0,2,1])
    x = torch.Tensor(x)
    x = Variable(x)
    x = self.conv(x)
    x = x.view(-1,152576)
    x = self.f1(x)
    return self.f2(x)
 
class CnnModel:
  def __init__(self):
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextCNN()
    self.model.load_state_dict(torch.load('model_params.pkl'))
 
  def predict(self, message):
    # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
    content = unicode(message)
    data = [self.word_to_id[x] for x in content if x in self.word_to_id]
    data = kr.preprocessing.sequence.pad_sequences([data],600)
    data = torch.LongTensor(data)
    y_pred_cls = self.model(data)
    class_index = torch.argmax(y_pred_cls[0]).item()
    return self.categories[class_index]
 
class RnnModel:
  def __init__(self):
    self.categories, self.cat_to_id = read_category()
    self.words, self.word_to_id = read_vocab(vocab_dir)
    self.model = TextRNN()
    self.model.load_state_dict(torch.load('model_rnn_params.pkl'))
 
  def predict(self, message):
    # 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行
    content = unicode(message)
    data = [self.word_to_id[x] for x in content if x in self.word_to_id]
    data = kr.preprocessing.sequence.pad_sequences([data], 600)
    data = torch.LongTensor(data)
    y_pred_cls = self.model(data)
    class_index = torch.argmax(y_pred_cls[0]).item()
    return self.categories[class_index]
 
 
if __name__ == '__main__':
  model = CnnModel()
  # model = RnnModel()
  test_demo = ['湖人助教力助科比恢复手感 他也是阿泰的精神导师新浪体育讯记者戴高乐报道 上赛季,科比的右手食指遭遇重创,他的投篮手感也因此大受影响。不过很快科比就调整了自己的投篮手型,并通过这一方式让自己的投篮命中率回升。而在这科比背后,有一位特别助教对科比帮助很大,他就是查克·珀森。珀森上赛季担任湖人的特别助教,除了帮助科比调整投篮手型之外,他的另一个重要任务就是担任阿泰的精神导师。来到湖人队之后,阿泰收敛起了暴躁的脾气,成为湖人夺冠路上不可或缺的一员,珀森的“心灵按摩”功不可没。经历了上赛季的成功之后,珀森本赛季被“升职”成为湖人队的全职助教,每场比赛,他都会坐在球场边,帮助禅师杰克逊一起指挥湖人球员在场上拼杀。对于珀森的工作,禅师非常欣赏,“查克非常善于分析问题,”菲尔·杰克逊说,“他总是在寻找问题的答案,同时也在找造成这一问题的原因,这是我们都非常乐于看到的。我会在平时把防守中出现的一些问题交给他,然后他会通过组织球员练习找到解决的办法。他在球员时代曾是一名很好的外线投手,不过现在他与内线球员的配合也相当不错。',
         '弗老大被裁美国媒体看热闹“特权”在中国像蠢蛋弗老大要走了。虽然他只在首钢男篮效力了13天,而且表现毫无亮点,大大地让球迷和俱乐部失望了,但就像中国人常说的“好聚好散”,队友还是友好地与他告别,俱乐部与他和平分手,球迷还请他留下了在北京的最后一次签名。相比之下,弗老大的同胞美国人却没那么“宽容”。他们嘲讽这位NBA前巨星的英雄迟暮,批评他在CBA的业余表现,还惊讶于中国人的“大方”。今天,北京首钢俱乐部将与弗朗西斯继续商讨解约一事。从昨日的进展来看,双方可以做到“买卖不成人意在”,但回到美国后,恐怕等待弗朗西斯的就没有这么轻松的环境了。进展@北京昨日与队友告别 最后一次为球迷签名弗朗西斯在13天里为首钢队打了4场比赛,3场的得分为0,只有一场得了2分。昨天是他来到北京的第14天,虽然他与首钢还未正式解约,但双方都明白“缘分已尽”。下午,弗朗西斯来到首钢俱乐部与队友们告别。弗朗西斯走到队友身边,依次与他们握手拥抱。“你们都对我很好,安排的条件也很好,我很喜欢这支球队,想融入你们,但我现在真的很不适应。希望你们']
  for i in test_demo:
    print(i,":",model.predict(i))

cnews_loader.py:

# coding: utf-8
 
import sys
from collections import Counter
 
import numpy as np
import tensorflow.contrib.keras as kr
 
if sys.version_info[0] > 2:
  is_py3 = True
else:
  reload(sys)
  sys.setdefaultencoding("utf-8")
  is_py3 = False
 
 
def native_word(word, encoding='utf-8'):
  """如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码"""
  if not is_py3:
    return word.encode(encoding)
  else:
    return word
 
 
def native_content(content):
  if not is_py3:
    return content.decode('utf-8')
  else:
    return content
 
 
def open_file(filename, mode='r'):
  """
  常用文件操作,可在python2和python3间切换.
  mode: 'r' or 'w' for read or write
  """
  if is_py3:
    return open(filename, mode, encoding='utf-8', errors='ignore')
  else:
    return open(filename, mode)
 
 
def read_file(filename):
  """读取文件数据"""
  contents, labels = [], []
  with open_file(filename) as f:
    for line in f:
      try:
        label, content = line.strip().split('\t')
        if content:
          contents.append(list(native_content(content)))
          labels.append(native_content(label))
      except:
        pass
  return contents, labels
 
 
def build_vocab(train_dir, vocab_dir, vocab_size=5000):
  """根据训练集构建词汇表,存储"""
  data_train, _ = read_file(train_dir)
 
  all_data = []
  for content in data_train:
    all_data.extend(content)
 
  counter = Counter(all_data)
  count_pairs = counter.most_common(vocab_size - 1)
  words, _ = list(zip(*count_pairs))
  # 添加一个 <PAD> 来将所有文本pad为同一长度
  words = ['<PAD>'] + list(words)
  open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')
 
 
def read_vocab(vocab_dir):
  """读取词汇表"""
  # words = open_file(vocab_dir).read().strip().split('\n')
  with open_file(vocab_dir) as fp:
    # 如果是py2 则每个值都转化为unicode
    words = [native_content(_.strip()) for _ in fp.readlines()]
  word_to_id = dict(zip(words, range(len(words))))
  return words, word_to_id
 
 
def read_category():
  """读取分类目录,固定"""
  categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
 
  categories = [native_content(x) for x in categories]
 
  cat_to_id = dict(zip(categories, range(len(categories))))
 
  return categories, cat_to_id
 
 
def to_words(content, words):
  """将id表示的内容转换为文字"""
  return ''.join(words[x] for x in content)
 
 
def process_file(filename, word_to_id, cat_to_id, max_length=600):
  """将文件转换为id表示"""
  contents, labels = read_file(filename)#读取训练数据的每一句话及其所对应的类别
  data_id, label_id = [], []
  for i in range(len(contents)):
    data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])#将每句话id化
    label_id.append(cat_to_id[labels[i]])#每句话对应的类别的id
  #
  # # 使用keras提供的pad_sequences来将文本pad为固定长度
  x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)
  y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id)) # 将标签转换为one-hot表示
  #
  return x_pad, y_pad
 
 
def batch_iter(x, y, batch_size=64):
  """生成批次数据"""
  data_len = len(x)
  num_batch = int((data_len - 1) / batch_size) + 1
 
  indices = np.random.permutation(np.arange(data_len))
  x_shuffle = x[indices]
  y_shuffle = y[indices]
 
  for i in range(num_batch):
    start_id = i * batch_size
    end_id = min((i + 1) * batch_size, data_len)
    yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

以上这篇pytorch实现用CNN和LSTM对文本进行分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python的webrtc库实现语音端点检测
May 31 Python
python破解zip加密文件的方法
May 31 Python
python 寻找list中最大元素对应的索引方法
Jun 28 Python
python 执行文件时额外参数获取的实例
Dec 18 Python
Python将主机名转换为IP地址的方法
Aug 14 Python
Python 变量的创建过程详解
Sep 02 Python
感知器基础原理及python实现过程详解
Sep 30 Python
python打印异常信息的两种实现方式
Dec 24 Python
在TensorFlow中屏蔽warning的方式
Feb 04 Python
python字典与json转换的方法总结
Dec 28 Python
Python之qq自动发消息的示例代码
Feb 18 Python
Pygame Rect区域位置的使用(图文)
Nov 17 Python
使用pytorch和torchtext进行文本分类的实例
Jan 08 #Python
python爬虫爬取监控教务系统的思路详解
Jan 08 #Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 #Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
Jan 08 #Python
Pytorch实现神经网络的分类方式
Jan 08 #Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
You might like
十天学会php之第十天
2006/10/09 PHP
set_include_path在win和linux下的区别
2008/01/10 PHP
php实现的通用图片处理类
2015/03/24 PHP
部署PHP时的4个配置修改说明
2015/10/19 PHP
thinkPHP5实现数据库添加内容的方法
2017/10/25 PHP
对laravel in 查询的使用方法详解
2019/10/09 PHP
js如何获取file控件的完整路径具体实现代码
2013/05/15 Javascript
jtable列中自定义button示例代码
2013/11/21 Javascript
jQuery中extend函数详解
2015/02/13 Javascript
JavaScript中几种排序算法的简单实现
2015/07/29 Javascript
使用coffeescript编写node.js项目的方法汇总
2015/08/05 Javascript
jQuery实现垂直半透明手风琴特效代码分享
2015/08/21 Javascript
js实现文字向上轮播功能
2017/01/13 Javascript
原生js实现电商侧边导航效果
2017/01/19 Javascript
JS简单实现获取元素的封装操作示例
2017/04/07 Javascript
jQuery实现监听下拉框选中内容发生改变操作示例
2018/07/13 jQuery
Vue动态路由缓存不相互影响的解决办法
2019/02/19 Javascript
微信小程序左滑删除实现代码实例
2019/09/16 Javascript
[02:30]DOTA2英雄基础教程 暗影恶魔
2013/12/17 DOTA
简洁的十分钟Python入门教程
2015/04/03 Python
使用FastCGI部署Python的Django应用的教程
2015/07/22 Python
windows下python安装paramiko模块和pycrypto模块(简单三步)
2017/07/06 Python
从DataFrame中提取出Series或DataFrame对象的方法
2018/11/10 Python
Numpy中对向量、矩阵的使用详解
2019/10/29 Python
英国家喻户晓的家居商店:The Range
2019/03/25 全球购物
MYSQL基础面试题
2012/05/13 面试题
工程业务员岗位职责
2013/12/31 职场文书
师德师风个人反思
2014/04/28 职场文书
竞选学委演讲稿
2014/09/13 职场文书
2014年反腐倡廉工作总结
2014/12/05 职场文书
2015年党日活动总结范文
2015/03/25 职场文书
2015大学迎新晚会主持词
2015/07/16 职场文书
JS中一些高效的魔法运算符总结
2021/05/06 Javascript
HTML+VUE分页实现炫酷物联网大屏功能
2021/05/27 Vue.js
redis使用不当导致应用卡死bug的过程解析
2021/07/01 Redis
解决MySQL添加新用户-ERROR 1045 (28000)的问题
2022/03/03 MySQL