使用keras实现BiLSTM+CNN+CRF文字标记NER


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

import keras
from sklearn.model_selection import train_test_split
import tensorflow as tf
from keras.callbacks import ModelCheckpoint,Callback
# import keras.backend as K
from keras.layers import *
from keras.models import Model
from keras.optimizers import SGD, RMSprop, Adagrad,Adam
from keras.models import *
from keras.metrics import *
from keras import backend as K
from keras.regularizers import *
from keras.metrics import categorical_accuracy
# from keras.regularizers import activity_l1 #通过L1正则项,使得输出更加稀疏
from keras_contrib.layers import CRF

from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

# from crf import CRFLayer,create_custom_objects

class LossHistory(Callback):
  def on_train_begin(self, logs={}):
    self.losses = []

  def on_batch_end(self, batch, logs={}):
    self.losses.append(logs.get('loss'))
# def on_epoch_end(self, epoch, logs=None):

word_input = Input(shape=(max_len,), dtype='int32', name='word_input')
word_emb = Embedding(len(char_value_dict)+2, output_dim=64, input_length=max_len, dropout=0.2, name='word_emb')(word_input)
bilstm = Bidirectional(LSTM(32, dropout_W=0.1, dropout_U=0.1, return_sequences=True))(word_emb)
bilstm_d = Dropout(0.1)(bilstm)
half_window_size = 2
paddinglayer = ZeroPadding1D(padding=half_window_size)(word_emb)
conv = Conv1D(nb_filter=50, filter_length=(2 * half_window_size + 1), border_mode='valid')(paddinglayer)
conv_d = Dropout(0.1)(conv)
dense_conv = TimeDistributed(Dense(50))(conv_d)
rnn_cnn_merge = merge([bilstm_d, dense_conv], mode='concat', concat_axis=2)
dense = TimeDistributed(Dense(class_label_count))(rnn_cnn_merge)
crf = CRF(class_label_count, sparse_target=False)
crf_output = crf(dense)
model = Model(input=[word_input], output=[crf_output])
model.compile(loss=crf.loss_function, optimizer='adam', metrics=[crf.accuracy])
model.summary()

# serialize model to JSON
model_json = model.to_json()
with open("model.json", "w") as json_file:
  json_file.write(model_json)

#编译模型
# model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['acc',])

# 用于保存验证集误差最小的参数,当验证集误差减少时,立马保存下来
checkpointer = ModelCheckpoint(filepath="bilstm_1102_k205_tf130.w", verbose=0, save_best_only=True, save_weights_only=True) #save_weights_only=True
history = LossHistory()

history = model.fit(x_train, y_train,
          batch_size=32, epochs=500,#validation_data = ([x_test, seq_lens_test], y_test),
          callbacks=[checkpointer, history, plotter],
          verbose=1,
          validation_split=0.1,
          )

补充知识:keras训练模型使用自定义CTC损失函数,重载模型时报错解决办法

使用keras训练模型,用到了ctc损失函数,需要自定义损失函数如下:

self.ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt)

其中loss为自定义函数,使用字典{‘ctc': lambda y_true, output: output}

训练完模型后需要重载模型,如下:

from keras.models import load_model

model=load_model('final_ctc_model.h5')

报错:

Unknown loss function : <lambda>

由于是自定义的损失函数需要加参数custom_objects,这里需要定义字典{'': lambda y_true, output: output},正确代码如下:

model=load_model('final_ctc_model.h5',custom_objects={'<lambda>': lambda y_true, output: output})

可能是因为要将自己定义的loss函数加入到keras函数里

在这之前试了很多次,如果用lambda y_true, output: output定义loss

函数字典名只能是'<lambda>',不能是别的字符

如果自定义一个函数如loss_func作为loss函数如:

self.ctc_model.compile(loss=loss_func, optimizer=opt)

可以在重载时使用

am=load_model('final_ctc_model.h5',custom_objects={'loss_func': loss_func})

此时注意字典名和函数名要相同

以上这篇使用keras实现BiLSTM+CNN+CRF文字标记NER就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python随机生成彩票号码的方法
Mar 05 Python
Python发送form-data请求及拼接form-data内容的方法
Mar 05 Python
CentOS 7下安装Python 3.5并与Python2.7兼容并存详解
Jul 07 Python
Flask 让jsonify返回的json串支持中文显示的方法
Mar 26 Python
TensorFlow 模型载入方法汇总(小结)
Jun 19 Python
详解Django中间件的5种自定义方法
Jul 26 Python
Tensorflow的梯度异步更新示例
Jan 23 Python
python安装dlib库报错问题及解决方法
Mar 16 Python
python开发实例之python使用Websocket库开发简单聊天工具实例详解(python+Websocket+JS)
Mar 18 Python
使用Python webdriver图书馆抢座自动预约的正确方法
Mar 04 Python
在pycharm中无法import所安装的库解决方案
May 31 Python
聊聊Python中关于a=[[]]*3的反思
Jun 02 Python
Python建造者模式案例运行原理解析
Jun 29 #Python
解决Keras中循环使用K.ctc_decode内存不释放的问题
Jun 29 #Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
You might like
php开发工具之vs2005图解
2008/01/12 PHP
解析dedecms空间迁移步骤详解
2013/05/15 PHP
PHP5.5在windows安装使用memcached服务端的方法
2014/04/16 PHP
php判断str字符串是否是xml格式数据的方法示例
2017/07/26 PHP
js中eval详解
2012/03/30 Javascript
artDialog 4.1.5 Dreamweaver代码提示/补全插件 附下载
2012/07/31 Javascript
JS动态添加option和删除option(附实例代码)
2013/04/01 Javascript
js设置cookie过期及清除浏览器对应名称的cookie
2013/10/24 Javascript
各种常用的JS函数整理
2013/10/25 Javascript
JavaScript新窗口与子窗口传值详解
2014/02/11 Javascript
JavaScript两种跨域技术全面介绍
2014/04/16 Javascript
深入分析js的冒泡事件
2014/12/05 Javascript
jQuery实现监控页面所有ajax请求的方法
2015/12/10 Javascript
JS判断日期格式是否合法的简单实例
2016/07/11 Javascript
AngularJs 指令详解及示例代码
2016/09/01 Javascript
JS检测数组类型的方法小结
2017/03/14 Javascript
使用jQuery和ajax代替iframe的方法(详解)
2017/04/12 jQuery
微信小程序websocket实现聊天功能
2020/03/30 Javascript
vue生成文件本地打开查看效果的实例
2018/09/06 Javascript
小程序点击图片实现png转jpg
2019/10/22 Javascript
Python ZipFile模块详解
2013/11/01 Python
python抓取某汽车网数据解析html存入excel示例
2013/12/04 Python
python监控网卡流量并使用graphite绘图的示例
2014/04/27 Python
Python进程间通信 multiProcessing Queue队列实现详解
2019/09/23 Python
python raise的基本使用
2020/09/10 Python
吃透移动端 Html5 响应式布局
2019/12/16 HTML / CSS
香蕉共和国Banana Republic官网:美国GAP旗下偏贵族风格服饰品牌
2016/11/21 全球购物
欧洲有机婴儿食品最大的市场:Organic Baby Food(供美国和加拿大)
2018/03/28 全球购物
GWT都有什么特性
2016/12/02 面试题
庆元旦活动总结
2014/07/09 职场文书
医院我们的节日活动实施方案
2014/08/22 职场文书
护士年终个人总结
2015/02/13 职场文书
2015年乡镇纪委工作总结
2015/05/26 职场文书
党员转正党支部意见
2015/06/02 职场文书
《一面五星红旗》教学反思
2016/02/23 职场文书
劳动合同变更协议书范本
2019/04/18 职场文书