使用keras实现BiLSTM+CNN+CRF文字标记NER


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

import keras
from sklearn.model_selection import train_test_split
import tensorflow as tf
from keras.callbacks import ModelCheckpoint,Callback
# import keras.backend as K
from keras.layers import *
from keras.models import Model
from keras.optimizers import SGD, RMSprop, Adagrad,Adam
from keras.models import *
from keras.metrics import *
from keras import backend as K
from keras.regularizers import *
from keras.metrics import categorical_accuracy
# from keras.regularizers import activity_l1 #通过L1正则项,使得输出更加稀疏
from keras_contrib.layers import CRF

from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

# from crf import CRFLayer,create_custom_objects

class LossHistory(Callback):
  def on_train_begin(self, logs={}):
    self.losses = []

  def on_batch_end(self, batch, logs={}):
    self.losses.append(logs.get('loss'))
# def on_epoch_end(self, epoch, logs=None):

word_input = Input(shape=(max_len,), dtype='int32', name='word_input')
word_emb = Embedding(len(char_value_dict)+2, output_dim=64, input_length=max_len, dropout=0.2, name='word_emb')(word_input)
bilstm = Bidirectional(LSTM(32, dropout_W=0.1, dropout_U=0.1, return_sequences=True))(word_emb)
bilstm_d = Dropout(0.1)(bilstm)
half_window_size = 2
paddinglayer = ZeroPadding1D(padding=half_window_size)(word_emb)
conv = Conv1D(nb_filter=50, filter_length=(2 * half_window_size + 1), border_mode='valid')(paddinglayer)
conv_d = Dropout(0.1)(conv)
dense_conv = TimeDistributed(Dense(50))(conv_d)
rnn_cnn_merge = merge([bilstm_d, dense_conv], mode='concat', concat_axis=2)
dense = TimeDistributed(Dense(class_label_count))(rnn_cnn_merge)
crf = CRF(class_label_count, sparse_target=False)
crf_output = crf(dense)
model = Model(input=[word_input], output=[crf_output])
model.compile(loss=crf.loss_function, optimizer='adam', metrics=[crf.accuracy])
model.summary()

# serialize model to JSON
model_json = model.to_json()
with open("model.json", "w") as json_file:
  json_file.write(model_json)

#编译模型
# model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['acc',])

# 用于保存验证集误差最小的参数,当验证集误差减少时,立马保存下来
checkpointer = ModelCheckpoint(filepath="bilstm_1102_k205_tf130.w", verbose=0, save_best_only=True, save_weights_only=True) #save_weights_only=True
history = LossHistory()

history = model.fit(x_train, y_train,
          batch_size=32, epochs=500,#validation_data = ([x_test, seq_lens_test], y_test),
          callbacks=[checkpointer, history, plotter],
          verbose=1,
          validation_split=0.1,
          )

补充知识:keras训练模型使用自定义CTC损失函数,重载模型时报错解决办法

使用keras训练模型,用到了ctc损失函数,需要自定义损失函数如下:

self.ctc_model.compile(loss={'ctc': lambda y_true, output: output}, optimizer=opt)

其中loss为自定义函数,使用字典{‘ctc': lambda y_true, output: output}

训练完模型后需要重载模型,如下:

from keras.models import load_model

model=load_model('final_ctc_model.h5')

报错:

Unknown loss function : <lambda>

由于是自定义的损失函数需要加参数custom_objects,这里需要定义字典{'': lambda y_true, output: output},正确代码如下:

model=load_model('final_ctc_model.h5',custom_objects={'<lambda>': lambda y_true, output: output})

可能是因为要将自己定义的loss函数加入到keras函数里

在这之前试了很多次,如果用lambda y_true, output: output定义loss

函数字典名只能是'<lambda>',不能是别的字符

如果自定义一个函数如loss_func作为loss函数如:

self.ctc_model.compile(loss=loss_func, optimizer=opt)

可以在重载时使用

am=load_model('final_ctc_model.h5',custom_objects={'loss_func': loss_func})

此时注意字典名和函数名要相同

以上这篇使用keras实现BiLSTM+CNN+CRF文字标记NER就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python的Django框架中使用通用视图的方法
Jul 21 Python
python利用MethodType绑定方法到类示例代码
Aug 27 Python
Python入门之三角函数atan2()函数详解
Nov 08 Python
Python实现简易Web爬虫详解
Jan 03 Python
python版学生管理系统
Jan 10 Python
python 自定义异常和异常捕捉的方法
Oct 18 Python
python实现给微信指定好友定时发送消息
Apr 29 Python
Python selenium 自动化脚本打包成一个exe文件(推荐)
Jan 14 Python
Python并发请求下限制QPS(每秒查询率)的实现代码
Jun 05 Python
使用keras实现Precise, Recall, F1-socre方式
Jun 15 Python
Python激活Anaconda环境变量的详细步骤
Jun 08 Python
Python识别花卉种类鉴定网络热门植物并自动整理分类
Apr 08 Python
Python建造者模式案例运行原理解析
Jun 29 #Python
解决Keras中循环使用K.ctc_decode内存不释放的问题
Jun 29 #Python
Python根据指定文件生成XML的方法
Jun 29 #Python
keras在构建LSTM模型时对变长序列的处理操作
Jun 29 #Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 #Python
使用keras框架cnn+ctc_loss识别不定长字符图片操作
Jun 29 #Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
You might like
虫族 ZERG 概述
2020/03/14 星际争霸
php中计算时间差的几种方法
2009/12/31 PHP
PHP 透明水印生成代码
2012/08/27 PHP
解析php时间戳与日期的转换
2013/06/06 PHP
一个严格的PHP Session会话超时时间设置方法
2014/06/10 PHP
PHP file_get_contents函数读取远程数据超时的解决方法
2015/05/13 PHP
php mysql操作mysql_connect连接数据库实例详解
2016/12/26 PHP
Yii2选项卡的简单使用
2017/05/26 PHP
屏蔽Flash右键信息的js代码
2010/01/17 Javascript
js中更短的 Array 类型转换
2011/10/30 Javascript
js 操作select和option常用代码整理
2012/12/13 Javascript
使用JQuery库提供的扩展功能实现自定义方法
2014/09/09 Javascript
js实现从数组里随机获取元素
2015/01/12 Javascript
Javascript基础教程之数据类型 (布尔型 Boolean)
2015/01/18 Javascript
javascript中函数作为参数调用的方法
2015/02/09 Javascript
AngularJS 服务详细讲解及示例代码
2016/08/17 Javascript
AngularJS入门教程之静态模板详解
2016/08/18 Javascript
bootstrap suggest搜索建议插件使用详解
2017/03/25 Javascript
详解vue-cli与webpack结合如何处理静态资源
2017/09/19 Javascript
JavaScript轮播停留效果的实现思路
2018/05/24 Javascript
微信小程序实现左侧滑动导航栏
2020/04/08 Javascript
基于Web Audio API实现音频可视化效果
2020/06/12 Javascript
探究一道价值25k的蚂蚁金服异步串行面试题
2020/08/21 Javascript
Python基础入门之seed()方法的使用
2015/05/15 Python
让Python更加充分的使用Sqlite3
2017/12/11 Python
Python日期时间对象转换为字符串的实例
2018/06/22 Python
解决python3捕获cx_oracle抛出的异常错误问题
2018/10/18 Python
Python面向对象之类的内置attr属性示例
2018/12/14 Python
python隐藏终端执行cmd命令的方法
2019/06/24 Python
python、PyTorch图像读取与numpy转换实例
2020/01/13 Python
python编写一个会算账的脚本的示例代码
2020/06/02 Python
Python命令行参数argv和argparse该如何使用
2021/02/08 Python
日本著名化妆品零售网站:Cosme Land
2019/03/01 全球购物
Perfume’s Club中文官网:西班牙美妆在线零售品牌
2020/08/24 全球购物
餐厅考勤管理制度
2014/01/28 职场文书
三八活动策划方案
2014/08/17 职场文书