使用keras框架cnn+ctc_loss识别不定长字符图片操作


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
#keras==2.0.5
#tensorflow==1.1.0

import os,sys,string
import sys
import logging
import multiprocessing
import time
import json
import cv2
import numpy as np
from sklearn.model_selection import train_test_split

import keras
import keras.backend as K
from keras.datasets import mnist
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras import backend as K
# from keras.utils.visualize_util import plot
from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

#识别字符集
char_ocr='0123456789' #string.digits
#定义识别字符串的最大长度
seq_len=8
#识别结果集合个数 0-9
label_count=len(char_ocr)+1

def get_label(filepath):
 # print(str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1])
 lab=[]
 for num in str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1]:
 lab.append(int(char_ocr.find(num)))
 if len(lab) < seq_len:
 cur_seq_len = len(lab)
 for i in range(seq_len - cur_seq_len):
  lab.append(label_count) #
 return lab

def gen_image_data(dir=r'data\train', file_list=[]):
 dir_path = dir
 for rt, dirs, files in os.walk(dir_path): # =pathDir
 for filename in files:
  # print (filename)
  if filename.find('.') >= 0:
  (shotname, extension) = os.path.splitext(filename)
  # print shotname,extension
  if extension == '.tif': # extension == '.png' or
   file_list.append(os.path.join('%s\\%s' % (rt, filename)))
   # print (filename)

 print(len(file_list))
 index = 0
 X = []
 Y = []
 for file in file_list:

 index += 1
 # if index>1000:
 # break
 # print(file)
 img = cv2.imread(file, 0)
 # print(np.shape(img))
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 img = cv2.resize(img, (150, 50), interpolation=cv2.INTER_CUBIC)
 img = cv2.transpose(img,(50,150))
 img =cv2.flip(img,1)
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 # cv2.waitKey()
 img = (255 - img) / 256 # 反色处理
 X.append([img])
 Y.append(get_label(file))
 # print(get_label(file))
 # print(np.shape(X))
 # print(np.shape(X))

 # print(np.shape(X))
 X = np.transpose(X, (0, 2, 3, 1))
 X = np.array(X)
 Y = np.array(Y)
 return X,Y

# the actual loss calc occurs here despite it not being
# an internal Keras loss function

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage:
 # y_pred = y_pred[:, 2:, :] 测试感觉没影响
 y_pred = y_pred[:, :, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

if __name__ == '__main__':
 height=150
 width=50
 input_tensor = Input((height, width, 1))
 x = input_tensor
 for i in range(3):
 x = Convolution2D(32*2**i, (3, 3), activation='relu', padding='same')(x)
 # x = Convolution2D(32*2**i, (3, 3), activation='relu')(x)
 x = MaxPooling2D(pool_size=(2, 2))(x)

 conv_shape = x.get_shape()
 # print(conv_shape)
 x = Reshape(target_shape=(int(conv_shape[1]), int(conv_shape[2] * conv_shape[3])))(x)

 x = Dense(32, activation='relu')(x)

 gru_1 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru1')(x)
 gru_1b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(x)
 gru1_merged = add([gru_1, gru_1b]) ###################

 gru_2 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
 gru_2b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(
 gru1_merged)
 x = concatenate([gru_2, gru_2b]) ######################
 x = Dropout(0.25)(x)
 x = Dense(label_count, kernel_initializer='he_normal', activation='softmax')(x)
 base_model = Model(inputs=input_tensor, outputs=x)

 labels = Input(name='the_labels', shape=[seq_len], dtype='float32')
 input_length = Input(name='input_length', shape=[1], dtype='int64')
 label_length = Input(name='label_length', shape=[1], dtype='int64')
 loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([x, labels, input_length, label_length])

 model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])
 model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
 model.summary()

 def test(base_model):
 file_list = []
 X, Y = gen_image_data(r'data\test', file_list)
 y_pred = base_model.predict(X)
 shape = y_pred[:, :, :].shape # 2:
 out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,
  :seq_len] # 2:
 print()
 error_count=0
 for i in range(len(X)):
  print(file_list[i])
  str_src = str(os.path.split(file_list[i])[-1]).split('.')[0].split('_')[-1]
  print(out[i])
  str_out = ''.join([str(x) for x in out[i] if x!=-1 ])
  print(str_src, str_out)
  if str_src!=str_out:
  error_count+=1
  print('################################',error_count)
  # img = cv2.imread(file_list[i])
  # cv2.imshow('image', img)
  # cv2.waitKey()

 class LossHistory(Callback):
 def on_train_begin(self, logs={}):
  self.losses = []

 def on_epoch_end(self, epoch, logs=None):
  model.save_weights('model_1018.w')
  base_model.save_weights('base_model_1018.w')
  test(base_model)

 def on_batch_end(self, batch, logs={}):
  self.losses.append(logs.get('loss'))


 # checkpointer = ModelCheckpoint(filepath="keras_seq2seq_1018.hdf5", verbose=1, save_best_only=True, )
 history = LossHistory()

 # base_model.load_weights('base_model_1018.w')
 # model.load_weights('model_1018.w')

 X,Y=gen_image_data()
 maxin=4900
 subseq_size = 100
 batch_size=10
 result=model.fit([X[:maxin], Y[:maxin], np.array(np.ones(len(X))*int(conv_shape[1]))[:maxin], np.array(np.ones(len(X))*seq_len)[:maxin]], Y[:maxin],
   batch_size=20,
   epochs=1000,
   callbacks=[history, plotter, EarlyStopping(patience=10)], #checkpointer, history,
   validation_data=([X[maxin:], Y[maxin:], np.array(np.ones(len(X))*int(conv_shape[1]))[maxin:], np.array(np.ones(len(X))*seq_len)[maxin:]], Y[maxin:]),
   )

 test(base_model)

 K.clear_session()

补充知识:日常填坑之keras.backend.ctc_batch_cost参数问题

InvalidArgumentError sequence_length(0) <=30错误

下面的代码是在网上绝大多数文章给出的关于k.ctc_batch_cost()函数的使用代码

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage: 
 y_pred = y_pred[:, 2:, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

可以注意到有一句:y_pred = y_pred[:, 2:, :],这里把y_pred 的第二维数据去掉了两列,说人话:把送进lstm序列的step减了2步。后来偶然在一篇文章中有提到说这里之所以减2是因为在将feature送入keras的lstm时自动少了2维,所以这里就写成这样了。估计是之前老版本的bug,现在的新版本已经修复了。如果依然按照上面的写法,会得到如下错误:

InvalidArgumentError sequence_length(0) <=30

'<='后面的数值 = 你cnn最后的输出维度 - 2。这个错误我找了很久,一直不明白30哪里来的,后来一行行的检查代码是发现了这里很可疑,于是改成如下形式错误解决。

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args 
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

训练时出现ctc_loss_calculator.cc:144] No valid path found或loss: inf错误

熟悉CTC算法的话,这个提示应该是ctc没找到有效路径。既然是没找到有效路径,那肯定是label和input之间哪个地方又出问题了!和input相关的错误已经解决了,那么肯定就是label的问题了。再看ctc_batch_cost的四个参数,labels和label_length这两个地方有可疑。对于ctc_batch_cost()的参数,labels需要one-hot编码,形状:[batch, max_labelLength],其中max_labelLength指预测的最大字符长度;label_length就是每个label中的字符长度了,受之前tf.ctc_loss的影响把这里都设置成了最大长度,所以报错。

对于参数labels而言,max_labelLength是能预测的最大字符长度。这个值与送lstm的featue的第二维,即特征序列的max_step有关,表面上看只要max_labelLength<max_step即可,但是如果小的不多依然会出现上述错误。至于到底要小多少,还得从ctc算法里找,由于ctc算法在标签中的每个字符后都加了一个空格,所以应该把这个长度考虑进去,所以有 max_labelLength < max_step//2。没仔细研究keras里ctc_batch_cost()函数的实现细节,上面是我的猜测。如果有很明确的答案,还请麻烦告诉我一声,谢了先!

错误代码:

batch_label_length = np.ones(batch_size) * max_labelLength

正确打开方式:

batch_x, batch_y = [], []
batch_input_length = np.ones(batch_size) * (max_img_weigth//8)
batch_label_length = []
for j in range(i, i + batch_size):
 x, y = self.get_img_data(index_all[j])
 batch_x.append(x)
 batch_y.append(y)
 batch_label_length.append(self.label_length[j])

最后附一张我的crnn的模型图:

使用keras框架cnn+ctc_loss识别不定长字符图片操作

以上这篇使用keras框架cnn+ctc_loss识别不定长字符图片操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 字符串与数字输出方法
Jul 16 Python
Python 安装第三方库 pip install 安装慢安装不上的解决办法
Jun 18 Python
用Python解数独的方法示例
Oct 24 Python
Python Numpy 自然数填充数组的实现
Nov 28 Python
使用opencv将视频帧转成图片输出
Dec 10 Python
爬虫代理池Python3WebSpider源代码测试过程解析
Dec 20 Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 Python
细数nn.BCELoss与nn.CrossEntropyLoss的区别
Feb 29 Python
Mysql数据库反向生成Django里面的models指令方式
May 18 Python
pycharm实现print输出保存到txt文件
Jun 01 Python
使用AJAX和Django获取数据的方法实例
Oct 25 Python
win10+anaconda安装yolov5的方法及问题解决方案
Apr 29 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
keras K.function获取某层的输出操作
Jun 29 #Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
You might like
php中session使用示例
2014/03/29 PHP
PHP函数addslashes和mysql_real_escape_string的区别
2014/04/22 PHP
学习php设计模式 php实现工厂模式(factory)
2015/12/07 PHP
PHP字典树(Trie树)定义与实现方法示例
2017/10/09 PHP
PHP基于mcript扩展实现对称加密功能示例
2019/02/21 PHP
php 策略模式原理与应用深入理解
2019/09/25 PHP
php写app用的框架整理
2019/09/29 PHP
点击按钮或链接不跳转只刷新页面的脚本整理
2013/10/22 Javascript
图解JavaScript中的this关键字
2020/05/28 Javascript
jQuery实现简单隔行变色的方法
2016/02/20 Javascript
JS数组操作(数组增加、删除、翻转、转字符串、取索引、截取(切片)slice、剪接splice、数组合并)
2016/05/20 Javascript
KnockoutJS 3.X API 第四章之数据控制流if绑定和ifnot绑定
2016/10/10 Javascript
JavaScript Drum Kit 指南(纯 JS 模拟敲鼓效果)
2017/07/23 Javascript
JavaScript事件处理程序详解
2017/09/19 Javascript
在Vue中使用highCharts绘制3d饼图的方法
2018/02/08 Javascript
vue兄弟组件传递数据的实例
2018/09/06 Javascript
微信小程序左滑删除功能开发案例详解
2018/11/12 Javascript
微信小程序实现动态获取元素宽高的方法分析
2018/12/10 Javascript
layui实现左侧菜单点击右侧内容区显示
2019/07/26 Javascript
vue的滚动条插件实现代码
2019/09/07 Javascript
jQuery表单校验插件validator使用方法详解
2020/02/18 jQuery
python 寻找优化使成本函数最小的最优解的方法
2017/12/28 Python
和孩子一起学习python之变量命名规则
2018/05/27 Python
python实现剪切功能
2019/01/23 Python
python实现抠图给证件照换背景源码
2019/08/20 Python
浅谈tensorflow中Dataset图片的批量读取及维度的操作详解
2020/01/20 Python
PyCharm设置注释字体颜色以及是否倾斜的操作
2020/09/16 Python
opencv python 对指针仪表读数识别的两种方式
2021/01/14 Python
全球性的女装店:storets
2019/06/12 全球购物
大学生个人简历中的自我评价
2013/12/27 职场文书
四风问题自查报告剖析材料
2014/02/08 职场文书
《童年的发现》教学反思
2014/02/14 职场文书
化工见习报告范文
2014/10/31 职场文书
2014年售后服务工作总结
2014/11/18 职场文书
班主任开场白
2015/06/01 职场文书
php TP5框架生成二维码链接
2021/04/01 PHP