使用keras框架cnn+ctc_loss识别不定长字符图片操作


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
#keras==2.0.5
#tensorflow==1.1.0

import os,sys,string
import sys
import logging
import multiprocessing
import time
import json
import cv2
import numpy as np
from sklearn.model_selection import train_test_split

import keras
import keras.backend as K
from keras.datasets import mnist
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras import backend as K
# from keras.utils.visualize_util import plot
from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

#识别字符集
char_ocr='0123456789' #string.digits
#定义识别字符串的最大长度
seq_len=8
#识别结果集合个数 0-9
label_count=len(char_ocr)+1

def get_label(filepath):
 # print(str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1])
 lab=[]
 for num in str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1]:
 lab.append(int(char_ocr.find(num)))
 if len(lab) < seq_len:
 cur_seq_len = len(lab)
 for i in range(seq_len - cur_seq_len):
  lab.append(label_count) #
 return lab

def gen_image_data(dir=r'data\train', file_list=[]):
 dir_path = dir
 for rt, dirs, files in os.walk(dir_path): # =pathDir
 for filename in files:
  # print (filename)
  if filename.find('.') >= 0:
  (shotname, extension) = os.path.splitext(filename)
  # print shotname,extension
  if extension == '.tif': # extension == '.png' or
   file_list.append(os.path.join('%s\\%s' % (rt, filename)))
   # print (filename)

 print(len(file_list))
 index = 0
 X = []
 Y = []
 for file in file_list:

 index += 1
 # if index>1000:
 # break
 # print(file)
 img = cv2.imread(file, 0)
 # print(np.shape(img))
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 img = cv2.resize(img, (150, 50), interpolation=cv2.INTER_CUBIC)
 img = cv2.transpose(img,(50,150))
 img =cv2.flip(img,1)
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 # cv2.waitKey()
 img = (255 - img) / 256 # 反色处理
 X.append([img])
 Y.append(get_label(file))
 # print(get_label(file))
 # print(np.shape(X))
 # print(np.shape(X))

 # print(np.shape(X))
 X = np.transpose(X, (0, 2, 3, 1))
 X = np.array(X)
 Y = np.array(Y)
 return X,Y

# the actual loss calc occurs here despite it not being
# an internal Keras loss function

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage:
 # y_pred = y_pred[:, 2:, :] 测试感觉没影响
 y_pred = y_pred[:, :, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

if __name__ == '__main__':
 height=150
 width=50
 input_tensor = Input((height, width, 1))
 x = input_tensor
 for i in range(3):
 x = Convolution2D(32*2**i, (3, 3), activation='relu', padding='same')(x)
 # x = Convolution2D(32*2**i, (3, 3), activation='relu')(x)
 x = MaxPooling2D(pool_size=(2, 2))(x)

 conv_shape = x.get_shape()
 # print(conv_shape)
 x = Reshape(target_shape=(int(conv_shape[1]), int(conv_shape[2] * conv_shape[3])))(x)

 x = Dense(32, activation='relu')(x)

 gru_1 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru1')(x)
 gru_1b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(x)
 gru1_merged = add([gru_1, gru_1b]) ###################

 gru_2 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
 gru_2b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(
 gru1_merged)
 x = concatenate([gru_2, gru_2b]) ######################
 x = Dropout(0.25)(x)
 x = Dense(label_count, kernel_initializer='he_normal', activation='softmax')(x)
 base_model = Model(inputs=input_tensor, outputs=x)

 labels = Input(name='the_labels', shape=[seq_len], dtype='float32')
 input_length = Input(name='input_length', shape=[1], dtype='int64')
 label_length = Input(name='label_length', shape=[1], dtype='int64')
 loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([x, labels, input_length, label_length])

 model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])
 model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
 model.summary()

 def test(base_model):
 file_list = []
 X, Y = gen_image_data(r'data\test', file_list)
 y_pred = base_model.predict(X)
 shape = y_pred[:, :, :].shape # 2:
 out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,
  :seq_len] # 2:
 print()
 error_count=0
 for i in range(len(X)):
  print(file_list[i])
  str_src = str(os.path.split(file_list[i])[-1]).split('.')[0].split('_')[-1]
  print(out[i])
  str_out = ''.join([str(x) for x in out[i] if x!=-1 ])
  print(str_src, str_out)
  if str_src!=str_out:
  error_count+=1
  print('################################',error_count)
  # img = cv2.imread(file_list[i])
  # cv2.imshow('image', img)
  # cv2.waitKey()

 class LossHistory(Callback):
 def on_train_begin(self, logs={}):
  self.losses = []

 def on_epoch_end(self, epoch, logs=None):
  model.save_weights('model_1018.w')
  base_model.save_weights('base_model_1018.w')
  test(base_model)

 def on_batch_end(self, batch, logs={}):
  self.losses.append(logs.get('loss'))


 # checkpointer = ModelCheckpoint(filepath="keras_seq2seq_1018.hdf5", verbose=1, save_best_only=True, )
 history = LossHistory()

 # base_model.load_weights('base_model_1018.w')
 # model.load_weights('model_1018.w')

 X,Y=gen_image_data()
 maxin=4900
 subseq_size = 100
 batch_size=10
 result=model.fit([X[:maxin], Y[:maxin], np.array(np.ones(len(X))*int(conv_shape[1]))[:maxin], np.array(np.ones(len(X))*seq_len)[:maxin]], Y[:maxin],
   batch_size=20,
   epochs=1000,
   callbacks=[history, plotter, EarlyStopping(patience=10)], #checkpointer, history,
   validation_data=([X[maxin:], Y[maxin:], np.array(np.ones(len(X))*int(conv_shape[1]))[maxin:], np.array(np.ones(len(X))*seq_len)[maxin:]], Y[maxin:]),
   )

 test(base_model)

 K.clear_session()

补充知识:日常填坑之keras.backend.ctc_batch_cost参数问题

InvalidArgumentError sequence_length(0) <=30错误

下面的代码是在网上绝大多数文章给出的关于k.ctc_batch_cost()函数的使用代码

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage: 
 y_pred = y_pred[:, 2:, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

可以注意到有一句:y_pred = y_pred[:, 2:, :],这里把y_pred 的第二维数据去掉了两列,说人话:把送进lstm序列的step减了2步。后来偶然在一篇文章中有提到说这里之所以减2是因为在将feature送入keras的lstm时自动少了2维,所以这里就写成这样了。估计是之前老版本的bug,现在的新版本已经修复了。如果依然按照上面的写法,会得到如下错误:

InvalidArgumentError sequence_length(0) <=30

'<='后面的数值 = 你cnn最后的输出维度 - 2。这个错误我找了很久,一直不明白30哪里来的,后来一行行的检查代码是发现了这里很可疑,于是改成如下形式错误解决。

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args 
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

训练时出现ctc_loss_calculator.cc:144] No valid path found或loss: inf错误

熟悉CTC算法的话,这个提示应该是ctc没找到有效路径。既然是没找到有效路径,那肯定是label和input之间哪个地方又出问题了!和input相关的错误已经解决了,那么肯定就是label的问题了。再看ctc_batch_cost的四个参数,labels和label_length这两个地方有可疑。对于ctc_batch_cost()的参数,labels需要one-hot编码,形状:[batch, max_labelLength],其中max_labelLength指预测的最大字符长度;label_length就是每个label中的字符长度了,受之前tf.ctc_loss的影响把这里都设置成了最大长度,所以报错。

对于参数labels而言,max_labelLength是能预测的最大字符长度。这个值与送lstm的featue的第二维,即特征序列的max_step有关,表面上看只要max_labelLength<max_step即可,但是如果小的不多依然会出现上述错误。至于到底要小多少,还得从ctc算法里找,由于ctc算法在标签中的每个字符后都加了一个空格,所以应该把这个长度考虑进去,所以有 max_labelLength < max_step//2。没仔细研究keras里ctc_batch_cost()函数的实现细节,上面是我的猜测。如果有很明确的答案,还请麻烦告诉我一声,谢了先!

错误代码:

batch_label_length = np.ones(batch_size) * max_labelLength

正确打开方式:

batch_x, batch_y = [], []
batch_input_length = np.ones(batch_size) * (max_img_weigth//8)
batch_label_length = []
for j in range(i, i + batch_size):
 x, y = self.get_img_data(index_all[j])
 batch_x.append(x)
 batch_y.append(y)
 batch_label_length.append(self.label_length[j])

最后附一张我的crnn的模型图:

使用keras框架cnn+ctc_loss识别不定长字符图片操作

以上这篇使用keras框架cnn+ctc_loss识别不定长字符图片操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简洁的十分钟Python入门教程
Apr 03 Python
Python实现获取磁盘剩余空间的2种方法
Jun 07 Python
wxPython之解决闪烁的问题
Jan 15 Python
python3.5绘制随机漫步图
Aug 27 Python
Django网络框架之HelloDjango项目创建教程
Jun 06 Python
python机器学习包mlxtend的安装和配置详解
Aug 21 Python
pytorch torch.expand和torch.repeat的区别详解
Nov 05 Python
基于Python快速处理PDF表格数据
Jun 03 Python
Python小白垃圾回收机制入门
Jun 09 Python
keras使用Sequence类调用大规模数据集进行训练的实现
Jun 22 Python
详解python百行有效代码实现汉诺塔小游戏(简约版)
Oct 30 Python
使用Django的JsonResponse返回数据的实现
Jan 15 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
keras K.function获取某层的输出操作
Jun 29 #Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
You might like
菜鸟修复电子管记
2021/03/02 无线电
php跨服务器访问方法小结
2015/05/12 PHP
php从文件夹随机读取文件的方法
2015/06/01 PHP
TP5框架使用QueryList采集框架爬小说操作示例
2020/03/26 PHP
js资料toString 方法
2007/03/13 Javascript
JavaScript 撑出页面文字换行
2009/06/15 Javascript
php跨域调用json的例子
2013/11/13 Javascript
js中文逗号转英文实现
2014/02/11 Javascript
js获取鼠标位置实例详解
2015/12/09 Javascript
Angualrjs和bootstrap相结合实现数据表格table
2017/03/30 Javascript
vue.js实现单选框、复选框和下拉框示例
2017/07/18 Javascript
vue 引入公共css文件的简单方法(推荐)
2018/01/20 Javascript
Angular2实现的秒表及改良版示例
2019/05/10 Javascript
vue3.0 搭建项目总结(详细步骤)
2019/05/20 Javascript
Vue 动态组件components和v-once指令的实现
2019/08/30 Javascript
python3模拟百度登录并实现百度贴吧签到示例分享(百度贴吧自动签到)
2014/02/24 Python
Python内置数据结构与操作符的练习题集锦
2016/07/01 Python
Python使用MyQR制作专属动态彩色二维码功能
2019/06/04 Python
python英语单词测试小程序代码实例
2019/09/09 Python
python统计指定目录内文件的代码行数
2019/09/19 Python
Python上下文管理器全实例详解
2019/11/12 Python
Python内置数据类型list各方法的性能测试过程解析
2020/01/07 Python
Python 实现Image和Ndarray互相转换
2020/02/19 Python
python中return不返回值的问题解析
2020/07/22 Python
Python extract及contains方法代码实例
2020/09/11 Python
翻新二手苹果产品的网络领导者:Mac of all Trades
2017/12/19 全球购物
促销活动策划方案
2014/01/12 职场文书
品质管理部岗位职责范文
2014/03/01 职场文书
汉语言文学毕业生自荐信范文
2014/03/24 职场文书
社区娱乐活动方案
2014/08/21 职场文书
酒店管理专业毕业生自我鉴定
2014/09/29 职场文书
王兆力在市委党的群众路线教育实践活动总结大会上的讲话稿
2014/10/25 职场文书
赢在执行观后感
2015/06/16 职场文书
婚庆答谢词大全
2015/09/29 职场文书
CSS3 制作的图片滚动效果
2021/04/14 HTML / CSS
浅析Python中的套接字编程
2021/06/22 Python