使用keras实现BiLSTM+CNN+CRF文字标记NER


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

import keras
from sklearn.model_selection import train_test_split
import tensorflow as tf
from keras.callbacks import ModelCheckpoint,Callback
# import keras.backend as K
from keras.layers import *
from keras.models import Model
from keras.optimizers import SGD, RMSprop, Adagrad,Adam
from keras.models import *
from keras.metrics import *
from keras import backend as K
from keras.regularizers import *
from keras.metrics import categorical_accuracy
# from keras.regularizers import activity_l1 #通过L1正则项,使得输出更加稀疏
from keras_contrib.layers import CRF

from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

# from crf import CRFLayer,create_custom_objects

class LossHistory(Callback):
  def on_train_begin(self, logs={}):
    self.losses = []

  def on_batch_end(self, batch, logs={}):
    self.losses.append(logs.get('loss'))
# def on_epoch_end(self, epoch, logs=None):

word_input = Input(shape=(max_len,), dtype='int32', name='word_input')
word_emb = Embedding(len(char_value_dict)+2, output_dim=64, input_length=max_len, dropout=0.2, name='word_emb')(word_input)
bilstm = Bidirectional(LSTM(32, dropout_W=0.1, dropout_U=0.1, return_sequences=True))(word_emb)
bilstm_d = Dropout(0.1)(bilstm)
half_window_size = 2
paddinglayer = ZeroPadding1D(padding=half_window_size)(word_emb)
conv = Conv1D(nb_filter=50, filter_length=(2 * half_window_size + 1), border_mode='valid')(paddinglayer)
conv_d = Dropout(0.1)(conv)
dense_conv = TimeDistributed(Dense(50))(conv_d)
rnn_cnn_merge = merge([bilstm_d, dense_conv], mode='concat', concat_axis=2)
dense = TimeDistributed(Dense(class_label_count))(rnn_cnn_merge)
crf = CRF(class_label_count, sparse_target=False)
crf_output = crf(dense)
model = Model(input=[word_input], output=[crf_output])
model.compile(loss=crf.loss_function, optimizer='adam', metrics=[crf.accuracy])
model.summary()

# serialize model to JSON
model_json = model.to_json()
with open("model.json", "w") as json_file:
  json_file.write(model_json)

#编译模型
# model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['acc',])

# 用于保存验证集误差最小的参数,当验证集误差减少时,立马保存下来
checkpointer = ModelCheckpoint(filepath="bilstm_1102_k205_tf130.w", verbose=0, save_best_only=True, save_weights_only=True) #save_weights_only=True
history = LossHistory()

history = model.fit(x_train, y_train,
          batch_size=32, epochs=500,#validation_data = ([x_test, seq_lens_test], y_test),
          callbacks=[checkpointer, history, plotter],
          verbose=1,
          validation_split=0.1,
          )

补充知识:keras训练模型使用自定义CTC损失函数,重载模型时报错解决办法

使用keras训练模型,用到了ctc损失函数,需要自定义损失函数如下:

self.ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt)

其中loss为自定义函数,使用字典{‘ctc': lambda y_true, output: output}

训练完模型后需要重载模型,如下:

from keras.models import load_model

model=load_model('final_ctc_model.h5')

报错:

Unknown loss function : <lambda>

由于是自定义的损失函数需要加参数custom_objects,这里需要定义字典{'': lambda y_true, output: output},正确代码如下:

model=load_model('final_ctc_model.h5',custom_objects={'<lambda>': lambda y_true, output: output})

可能是因为要将自己定义的loss函数加入到keras函数里

在这之前试了很多次,如果用lambda y_true, output: output定义loss

函数字典名只能是'<lambda>',不能是别的字符

如果自定义一个函数如loss_func作为loss函数如:

self.ctc_model.compile(loss=loss_func, optimizer=opt)

可以在重载时使用

am=load_model('final_ctc_model.h5',custom_objects={'loss_func': loss_func})

此时注意字典名和函数名要相同

以上这篇使用keras实现BiLSTM+CNN+CRF文字标记NER就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的迭代器和生成器使用实例
Jan 14 Python
Python中的条件判断语句基础学习教程
Feb 07 Python
python动态网页批量爬取
Feb 14 Python
Python+树莓派+YOLO打造一款人工智能照相机
Jan 02 Python
浅谈django的render函数的参数问题
Oct 16 Python
python基于递归解决背包问题详解
Jul 03 Python
Django 实现图片上传和显示过程详解
Jul 18 Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 Python
python import 上级目录的导入
Nov 03 Python
python 下载文件的多种方法汇总
Nov 17 Python
python spilt()分隔字符串的实现示例
May 21 Python
Python开发工具Pycharm的安装以及使用步骤总结
Jun 24 Python
Python建造者模式案例运行原理解析
Jun 29 #Python
解决Keras中循环使用K.ctc_decode内存不释放的问题
Jun 29 #Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
You might like
smarty实例教程
2006/11/19 PHP
PHP与MySQL开发中页面乱码的产生与解决
2008/03/27 PHP
PHP pathinfo()获得文件的路径、名称等信息说明
2011/09/13 PHP
解析php中static,const与define的使用区别
2013/06/18 PHP
php使用cookie保存登录用户名的方法
2015/01/26 PHP
PHP unset函数原理及使用方法解析
2020/08/14 PHP
提高代码性能技巧谈—以创建千行表格为例
2006/07/01 Javascript
地震发生中逃生十大法则
2008/05/12 Javascript
JQuery对checkbox操作 (循环获取)
2011/05/20 Javascript
JavaScript学习笔记(二) js对象
2011/10/25 Javascript
JavaScript中操作Mysql数据库实例
2015/04/02 Javascript
jquery实现浮动的侧栏实例
2015/06/25 Javascript
JavaScript常用判断写法大全(推荐)
2016/05/30 Javascript
轻松掌握JavaScript享元模式
2016/08/27 Javascript
模板视图和AngularJS之间冲突的解决方法
2016/11/22 Javascript
jQuery实现可移动选项的左右下拉列表示例
2016/12/26 Javascript
JavaScript 事件流、事件处理程序及事件对象总结
2017/04/01 Javascript
使用vue实现简单键盘的示例(支持移动端和pc端)
2017/12/25 Javascript
JS 实现分页打印功能
2018/05/16 Javascript
Angular4 组件通讯方法大全(推荐)
2018/07/12 Javascript
微信实现自动跳转到用其他浏览器打开指定APP下载
2019/02/15 Javascript
Vue的状态管理vuex使用方法详解
2020/02/05 Javascript
在vs code 中如何创建一个自己的 Vue 模板代码
2020/11/10 Javascript
javascript实现倒计时关闭广告
2021/02/09 Javascript
使用50行Python代码从零开始实现一个AI平衡小游戏
2018/11/21 Python
python 读取鼠标点击坐标的实例
2018/12/29 Python
对python mayavi三维绘图的实现详解
2019/01/08 Python
在tensorflow中实现屏蔽输出的log信息
2020/02/04 Python
Django Serializer HiddenField隐藏字段实例
2020/03/31 Python
如何利用Python给自己的头像加一个小国旗(小月饼)
2020/10/02 Python
中级会计职业生涯规划书
2014/03/01 职场文书
公司新年寄语
2014/04/04 职场文书
教师见习期自我鉴定
2014/04/28 职场文书
民用住房租房协议书
2014/10/29 职场文书
幼儿园教师暑期培训心得体会
2016/01/09 职场文书
CSS 鼠标选中文字后改变背景色的实现代码
2023/05/21 HTML / CSS