使用keras框架cnn+ctc_loss识别不定长字符图片操作


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
#keras==2.0.5
#tensorflow==1.1.0

import os,sys,string
import sys
import logging
import multiprocessing
import time
import json
import cv2
import numpy as np
from sklearn.model_selection import train_test_split

import keras
import keras.backend as K
from keras.datasets import mnist
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras import backend as K
# from keras.utils.visualize_util import plot
from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

#识别字符集
char_ocr='0123456789' #string.digits
#定义识别字符串的最大长度
seq_len=8
#识别结果集合个数 0-9
label_count=len(char_ocr)+1

def get_label(filepath):
 # print(str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1])
 lab=[]
 for num in str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1]:
 lab.append(int(char_ocr.find(num)))
 if len(lab) < seq_len:
 cur_seq_len = len(lab)
 for i in range(seq_len - cur_seq_len):
  lab.append(label_count) #
 return lab

def gen_image_data(dir=r'data\train', file_list=[]):
 dir_path = dir
 for rt, dirs, files in os.walk(dir_path): # =pathDir
 for filename in files:
  # print (filename)
  if filename.find('.') >= 0:
  (shotname, extension) = os.path.splitext(filename)
  # print shotname,extension
  if extension == '.tif': # extension == '.png' or
   file_list.append(os.path.join('%s\\%s' % (rt, filename)))
   # print (filename)

 print(len(file_list))
 index = 0
 X = []
 Y = []
 for file in file_list:

 index += 1
 # if index>1000:
 # break
 # print(file)
 img = cv2.imread(file, 0)
 # print(np.shape(img))
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 img = cv2.resize(img, (150, 50), interpolation=cv2.INTER_CUBIC)
 img = cv2.transpose(img,(50,150))
 img =cv2.flip(img,1)
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 # cv2.waitKey()
 img = (255 - img) / 256 # 反色处理
 X.append([img])
 Y.append(get_label(file))
 # print(get_label(file))
 # print(np.shape(X))
 # print(np.shape(X))

 # print(np.shape(X))
 X = np.transpose(X, (0, 2, 3, 1))
 X = np.array(X)
 Y = np.array(Y)
 return X,Y

# the actual loss calc occurs here despite it not being
# an internal Keras loss function

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage:
 # y_pred = y_pred[:, 2:, :] 测试感觉没影响
 y_pred = y_pred[:, :, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

if __name__ == '__main__':
 height=150
 width=50
 input_tensor = Input((height, width, 1))
 x = input_tensor
 for i in range(3):
 x = Convolution2D(32*2**i, (3, 3), activation='relu', padding='same')(x)
 # x = Convolution2D(32*2**i, (3, 3), activation='relu')(x)
 x = MaxPooling2D(pool_size=(2, 2))(x)

 conv_shape = x.get_shape()
 # print(conv_shape)
 x = Reshape(target_shape=(int(conv_shape[1]), int(conv_shape[2] * conv_shape[3])))(x)

 x = Dense(32, activation='relu')(x)

 gru_1 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru1')(x)
 gru_1b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(x)
 gru1_merged = add([gru_1, gru_1b]) ###################

 gru_2 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
 gru_2b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(
 gru1_merged)
 x = concatenate([gru_2, gru_2b]) ######################
 x = Dropout(0.25)(x)
 x = Dense(label_count, kernel_initializer='he_normal', activation='softmax')(x)
 base_model = Model(inputs=input_tensor, outputs=x)

 labels = Input(name='the_labels', shape=[seq_len], dtype='float32')
 input_length = Input(name='input_length', shape=[1], dtype='int64')
 label_length = Input(name='label_length', shape=[1], dtype='int64')
 loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([x, labels, input_length, label_length])

 model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])
 model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
 model.summary()

 def test(base_model):
 file_list = []
 X, Y = gen_image_data(r'data\test', file_list)
 y_pred = base_model.predict(X)
 shape = y_pred[:, :, :].shape # 2:
 out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,
  :seq_len] # 2:
 print()
 error_count=0
 for i in range(len(X)):
  print(file_list[i])
  str_src = str(os.path.split(file_list[i])[-1]).split('.')[0].split('_')[-1]
  print(out[i])
  str_out = ''.join([str(x) for x in out[i] if x!=-1 ])
  print(str_src, str_out)
  if str_src!=str_out:
  error_count+=1
  print('################################',error_count)
  # img = cv2.imread(file_list[i])
  # cv2.imshow('image', img)
  # cv2.waitKey()

 class LossHistory(Callback):
 def on_train_begin(self, logs={}):
  self.losses = []

 def on_epoch_end(self, epoch, logs=None):
  model.save_weights('model_1018.w')
  base_model.save_weights('base_model_1018.w')
  test(base_model)

 def on_batch_end(self, batch, logs={}):
  self.losses.append(logs.get('loss'))


 # checkpointer = ModelCheckpoint(filepath="keras_seq2seq_1018.hdf5", verbose=1, save_best_only=True, )
 history = LossHistory()

 # base_model.load_weights('base_model_1018.w')
 # model.load_weights('model_1018.w')

 X,Y=gen_image_data()
 maxin=4900
 subseq_size = 100
 batch_size=10
 result=model.fit([X[:maxin], Y[:maxin], np.array(np.ones(len(X))*int(conv_shape[1]))[:maxin], np.array(np.ones(len(X))*seq_len)[:maxin]], Y[:maxin],
   batch_size=20,
   epochs=1000,
   callbacks=[history, plotter, EarlyStopping(patience=10)], #checkpointer, history,
   validation_data=([X[maxin:], Y[maxin:], np.array(np.ones(len(X))*int(conv_shape[1]))[maxin:], np.array(np.ones(len(X))*seq_len)[maxin:]], Y[maxin:]),
   )

 test(base_model)

 K.clear_session()

补充知识:日常填坑之keras.backend.ctc_batch_cost参数问题

InvalidArgumentError sequence_length(0) <=30错误

下面的代码是在网上绝大多数文章给出的关于k.ctc_batch_cost()函数的使用代码

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage: 
 y_pred = y_pred[:, 2:, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

可以注意到有一句:y_pred = y_pred[:, 2:, :],这里把y_pred 的第二维数据去掉了两列,说人话:把送进lstm序列的step减了2步。后来偶然在一篇文章中有提到说这里之所以减2是因为在将feature送入keras的lstm时自动少了2维,所以这里就写成这样了。估计是之前老版本的bug,现在的新版本已经修复了。如果依然按照上面的写法,会得到如下错误:

InvalidArgumentError sequence_length(0) <=30

'<='后面的数值 = 你cnn最后的输出维度 - 2。这个错误我找了很久,一直不明白30哪里来的,后来一行行的检查代码是发现了这里很可疑,于是改成如下形式错误解决。

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args 
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

训练时出现ctc_loss_calculator.cc:144] No valid path found或loss: inf错误

熟悉CTC算法的话,这个提示应该是ctc没找到有效路径。既然是没找到有效路径,那肯定是label和input之间哪个地方又出问题了!和input相关的错误已经解决了,那么肯定就是label的问题了。再看ctc_batch_cost的四个参数,labels和label_length这两个地方有可疑。对于ctc_batch_cost()的参数,labels需要one-hot编码,形状:[batch, max_labelLength],其中max_labelLength指预测的最大字符长度;label_length就是每个label中的字符长度了,受之前tf.ctc_loss的影响把这里都设置成了最大长度,所以报错。

对于参数labels而言,max_labelLength是能预测的最大字符长度。这个值与送lstm的featue的第二维,即特征序列的max_step有关,表面上看只要max_labelLength<max_step即可,但是如果小的不多依然会出现上述错误。至于到底要小多少,还得从ctc算法里找,由于ctc算法在标签中的每个字符后都加了一个空格,所以应该把这个长度考虑进去,所以有 max_labelLength < max_step//2。没仔细研究keras里ctc_batch_cost()函数的实现细节,上面是我的猜测。如果有很明确的答案,还请麻烦告诉我一声,谢了先!

错误代码:

batch_label_length = np.ones(batch_size) * max_labelLength

正确打开方式:

batch_x, batch_y = [], []
batch_input_length = np.ones(batch_size) * (max_img_weigth//8)
batch_label_length = []
for j in range(i, i + batch_size):
 x, y = self.get_img_data(index_all[j])
 batch_x.append(x)
 batch_y.append(y)
 batch_label_length.append(self.label_length[j])

最后附一张我的crnn的模型图:

使用keras框架cnn+ctc_loss识别不定长字符图片操作

以上这篇使用keras框架cnn+ctc_loss识别不定长字符图片操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序设计入门(3)数组的使用
Jun 16 Python
将Python代码嵌入C++程序进行编写的实例
Jul 31 Python
Centos7 Python3下安装scrapy的详细步骤
Mar 15 Python
Pandas标记删除重复记录的方法
Apr 08 Python
PyTorch上实现卷积神经网络CNN的方法
Apr 28 Python
softmax及python实现过程解析
Sep 30 Python
全网首秀之Pycharm十大实用技巧(推荐)
Apr 27 Python
pandas DataFrame运算的实现
Jun 14 Python
Django Form设置文本框为readonly操作
Jul 03 Python
Python使用xlrd实现读取合并单元格
Jul 09 Python
详细分析Python collections工具库
Jul 16 Python
python实现定时发送邮件到指定邮箱
Dec 23 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
keras K.function获取某层的输出操作
Jun 29 #Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
You might like
PHP开发文件系统实例讲解
2006/10/09 PHP
PHP在获取指定目录下的目录,在获取的目录下面再创建文件,多平台
2011/08/03 PHP
phpize的深入理解
2013/06/03 PHP
php smarty truncate UTF8乱码问题解决办法
2014/06/13 PHP
详解PHP的Laravel框架中Eloquent对象关系映射使用
2016/02/26 PHP
PJ Blog修改-禁止复制的代码和方法
2006/10/25 Javascript
JavaScript中的new的使用方法与注意事项
2007/05/16 Javascript
JavaScript高级程序设计 阅读笔记(十三) js定义类或对象
2012/08/14 Javascript
javascript获取作用在元素上面的样式属性代码
2012/09/20 Javascript
jquery数据验证插件(自制,简单,练手)实例代码
2013/10/24 Javascript
深入理解JavaScript中的预解析
2017/01/04 Javascript
详解本地Node.js服务器作为api服务器的解决办法
2017/02/28 Javascript
Vue2.0权限树组件实现代码
2017/08/29 Javascript
AngularJS实现的select二级联动下拉菜单功能示例
2017/10/25 Javascript
基于jQuery实现的设置文本区域的光标位置
2018/06/15 jQuery
javascript实现小型区块链功能
2019/04/03 Javascript
如何使用jQuery操作Cookies方法解析
2020/09/08 jQuery
element-ui中el-upload多文件一次性上传的实现
2020/12/02 Javascript
在Django的URLconf中进行函数导入的方法
2015/07/18 Python
python 保存float类型的小数的位数方法
2018/10/17 Python
Python线程之定位与销毁的实现
2019/02/17 Python
让IE6、IE7、IE8支持CSS3的脚本
2010/07/20 HTML / CSS
家得宝官网:The Home Depot(全球最大的家居装饰专业零售商)
2018/12/17 全球购物
英国领先的餐饮折扣俱乐部:Gourmet Society
2020/07/26 全球购物
总务岗位职责
2013/11/19 职场文书
大学生的网上创业计划书
2013/12/31 职场文书
应届毕业生应聘自荐信范文
2014/02/26 职场文书
市场部经理岗位职责
2014/04/10 职场文书
三年级小学生评语
2014/04/22 职场文书
民族团结先进集体事迹材料
2014/05/22 职场文书
幼儿园国庆节活动总结
2015/03/23 职场文书
学雷锋献爱心活动总结
2015/05/11 职场文书
幼儿园保育员随笔
2015/08/14 职场文书
2016基督教会圣诞节开幕词
2016/03/04 职场文书
2019最新版劳务派遣管理制度
2019/08/16 职场文书
vue 把二维或多维数组转一维数组
2022/04/24 Vue.js