使用keras实现BiLSTM+CNN+CRF文字标记NER


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

import keras
from sklearn.model_selection import train_test_split
import tensorflow as tf
from keras.callbacks import ModelCheckpoint,Callback
# import keras.backend as K
from keras.layers import *
from keras.models import Model
from keras.optimizers import SGD, RMSprop, Adagrad,Adam
from keras.models import *
from keras.metrics import *
from keras import backend as K
from keras.regularizers import *
from keras.metrics import categorical_accuracy
# from keras.regularizers import activity_l1 #通过L1正则项,使得输出更加稀疏
from keras_contrib.layers import CRF

from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

# from crf import CRFLayer,create_custom_objects

class LossHistory(Callback):
  def on_train_begin(self, logs={}):
    self.losses = []

  def on_batch_end(self, batch, logs={}):
    self.losses.append(logs.get('loss'))
# def on_epoch_end(self, epoch, logs=None):

word_input = Input(shape=(max_len,), dtype='int32', name='word_input')
word_emb = Embedding(len(char_value_dict)+2, output_dim=64, input_length=max_len, dropout=0.2, name='word_emb')(word_input)
bilstm = Bidirectional(LSTM(32, dropout_W=0.1, dropout_U=0.1, return_sequences=True))(word_emb)
bilstm_d = Dropout(0.1)(bilstm)
half_window_size = 2
paddinglayer = ZeroPadding1D(padding=half_window_size)(word_emb)
conv = Conv1D(nb_filter=50, filter_length=(2 * half_window_size + 1), border_mode='valid')(paddinglayer)
conv_d = Dropout(0.1)(conv)
dense_conv = TimeDistributed(Dense(50))(conv_d)
rnn_cnn_merge = merge([bilstm_d, dense_conv], mode='concat', concat_axis=2)
dense = TimeDistributed(Dense(class_label_count))(rnn_cnn_merge)
crf = CRF(class_label_count, sparse_target=False)
crf_output = crf(dense)
model = Model(input=[word_input], output=[crf_output])
model.compile(loss=crf.loss_function, optimizer='adam', metrics=[crf.accuracy])
model.summary()

# serialize model to JSON
model_json = model.to_json()
with open("model.json", "w") as json_file:
  json_file.write(model_json)

#编译模型
# model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['acc',])

# 用于保存验证集误差最小的参数,当验证集误差减少时,立马保存下来
checkpointer = ModelCheckpoint(filepath="bilstm_1102_k205_tf130.w", verbose=0, save_best_only=True, save_weights_only=True) #save_weights_only=True
history = LossHistory()

history = model.fit(x_train, y_train,
          batch_size=32, epochs=500,#validation_data = ([x_test, seq_lens_test], y_test),
          callbacks=[checkpointer, history, plotter],
          verbose=1,
          validation_split=0.1,
          )

补充知识:keras训练模型使用自定义CTC损失函数,重载模型时报错解决办法

使用keras训练模型,用到了ctc损失函数,需要自定义损失函数如下:

self.ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt)

其中loss为自定义函数,使用字典{‘ctc': lambda y_true, output: output}

训练完模型后需要重载模型,如下:

from keras.models import load_model

model=load_model('final_ctc_model.h5')

报错:

Unknown loss function : <lambda>

由于是自定义的损失函数需要加参数custom_objects,这里需要定义字典{'': lambda y_true, output: output},正确代码如下:

model=load_model('final_ctc_model.h5',custom_objects={'<lambda>': lambda y_true, output: output})

可能是因为要将自己定义的loss函数加入到keras函数里

在这之前试了很多次,如果用lambda y_true, output: output定义loss

函数字典名只能是'<lambda>',不能是别的字符

如果自定义一个函数如loss_func作为loss函数如:

self.ctc_model.compile(loss=loss_func, optimizer=opt)

可以在重载时使用

am=load_model('final_ctc_model.h5',custom_objects={'loss_func': loss_func})

此时注意字典名和函数名要相同

以上这篇使用keras实现BiLSTM+CNN+CRF文字标记NER就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的一只从百度开始不断搜索的小爬虫
Aug 13 Python
Python3之文件读写操作的实例讲解
Jan 23 Python
用python代码将tiff图片存储到jpg的方法
Dec 04 Python
python 内置模块详解
Jan 01 Python
Python之循环结构
Jan 15 Python
Python Pandas数据结构简单介绍
Jul 03 Python
Win10+GPU版Pytorch1.1安装的安装步骤
Sep 27 Python
Python中Flask-RESTful编写API接口(小白入门)
Dec 11 Python
TensorBoard 计算图的可视化实现
Feb 15 Python
为什么是 Python -m
Jun 19 Python
如何用python 操作zookeeper
Dec 28 Python
python缺失值的解决方法总结
Jun 09 Python
Python建造者模式案例运行原理解析
Jun 29 #Python
解决Keras中循环使用K.ctc_decode内存不释放的问题
Jun 29 #Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
You might like
我的论坛源代码(三)
2006/10/09 PHP
php 文件状态缓存带来的问题
2008/12/14 PHP
php 用checkbox一次性删除多条记录的方法
2010/02/23 PHP
培养自己的php编码规范
2015/09/28 PHP
php 实现简单的登录功能示例【基于thinkPHP框架】
2019/12/02 PHP
js实现的切换面板实例代码
2013/06/17 Javascript
Js判断CSS文件加载完毕的具体实现
2014/01/17 Javascript
使用Jquery获取带特殊符号的ID 标签的方法
2014/04/30 Javascript
原生js实现模拟滚动条
2015/06/15 Javascript
JavaScript使表单中的内容显示在屏幕上的方法
2015/06/29 Javascript
详解Node.js模块间共享数据库连接的方法
2016/05/24 Javascript
详解React-Todos入门例子
2016/11/08 Javascript
JS实现重新加载当前页面
2016/11/29 Javascript
基于vue2框架的机器人自动回复mini-project实例代码
2017/06/13 Javascript
jquery.uploadView 实现图片预览上传功能
2017/08/10 jQuery
JavaScript实现写入文件到本地的方法【基于FileSaver.js插件】
2018/03/15 Javascript
Echarts动态加载多条折线图的实现代码
2019/05/24 Javascript
JavaScript中0、空字符串、'0'是true还是false的知识点分享
2019/09/16 Javascript
[52:03]Secret vs VG 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/20 DOTA
[01:02:47]EG vs Secret 2019国际邀请赛淘汰赛 胜者组 BO3 第一场 8.21.mp4
2020/07/19 DOTA
python3之微信文章爬虫实例讲解
2017/07/12 Python
python select.select模块通信全过程解析
2017/09/20 Python
利用Django内置的认证视图实现用户密码重置功能详解
2017/11/24 Python
Python线程同步的实现代码
2018/10/03 Python
python字符串的拼接方法总结
2019/11/18 Python
Python基于模块Paramiko实现SSHv2协议
2020/04/28 Python
记录模型训练时loss值的变化情况
2020/06/16 Python
Django serializer优化类视图的实现示例
2020/07/16 Python
英国女性化妆品收纳和家具网站:Beautify
2019/12/07 全球购物
俄罗斯鲜花递送:AMF
2020/04/24 全球购物
EJB的角色和三个对象
2015/12/31 面试题
项目开发计划书
2014/01/09 职场文书
中药学专业毕业生推荐信
2014/07/10 职场文书
教师群众路线教育实践活动个人对照检查材料
2014/11/04 职场文书
汽车销售助理岗位职责
2015/04/14 职场文书
工程竣工验收申请报告
2015/05/15 职场文书