使用keras框架cnn+ctc_loss识别不定长字符图片操作


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
#keras==2.0.5
#tensorflow==1.1.0

import os,sys,string
import sys
import logging
import multiprocessing
import time
import json
import cv2
import numpy as np
from sklearn.model_selection import train_test_split

import keras
import keras.backend as K
from keras.datasets import mnist
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras import backend as K
# from keras.utils.visualize_util import plot
from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

#识别字符集
char_ocr='0123456789' #string.digits
#定义识别字符串的最大长度
seq_len=8
#识别结果集合个数 0-9
label_count=len(char_ocr)+1

def get_label(filepath):
 # print(str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1])
 lab=[]
 for num in str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1]:
 lab.append(int(char_ocr.find(num)))
 if len(lab) < seq_len:
 cur_seq_len = len(lab)
 for i in range(seq_len - cur_seq_len):
  lab.append(label_count) #
 return lab

def gen_image_data(dir=r'data\train', file_list=[]):
 dir_path = dir
 for rt, dirs, files in os.walk(dir_path): # =pathDir
 for filename in files:
  # print (filename)
  if filename.find('.') >= 0:
  (shotname, extension) = os.path.splitext(filename)
  # print shotname,extension
  if extension == '.tif': # extension == '.png' or
   file_list.append(os.path.join('%s\\%s' % (rt, filename)))
   # print (filename)

 print(len(file_list))
 index = 0
 X = []
 Y = []
 for file in file_list:

 index += 1
 # if index>1000:
 # break
 # print(file)
 img = cv2.imread(file, 0)
 # print(np.shape(img))
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 img = cv2.resize(img, (150, 50), interpolation=cv2.INTER_CUBIC)
 img = cv2.transpose(img,(50,150))
 img =cv2.flip(img,1)
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 # cv2.waitKey()
 img = (255 - img) / 256 # 反色处理
 X.append([img])
 Y.append(get_label(file))
 # print(get_label(file))
 # print(np.shape(X))
 # print(np.shape(X))

 # print(np.shape(X))
 X = np.transpose(X, (0, 2, 3, 1))
 X = np.array(X)
 Y = np.array(Y)
 return X,Y

# the actual loss calc occurs here despite it not being
# an internal Keras loss function

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage:
 # y_pred = y_pred[:, 2:, :] 测试感觉没影响
 y_pred = y_pred[:, :, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

if __name__ == '__main__':
 height=150
 width=50
 input_tensor = Input((height, width, 1))
 x = input_tensor
 for i in range(3):
 x = Convolution2D(32*2**i, (3, 3), activation='relu', padding='same')(x)
 # x = Convolution2D(32*2**i, (3, 3), activation='relu')(x)
 x = MaxPooling2D(pool_size=(2, 2))(x)

 conv_shape = x.get_shape()
 # print(conv_shape)
 x = Reshape(target_shape=(int(conv_shape[1]), int(conv_shape[2] * conv_shape[3])))(x)

 x = Dense(32, activation='relu')(x)

 gru_1 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru1')(x)
 gru_1b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(x)
 gru1_merged = add([gru_1, gru_1b]) ###################

 gru_2 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
 gru_2b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(
 gru1_merged)
 x = concatenate([gru_2, gru_2b]) ######################
 x = Dropout(0.25)(x)
 x = Dense(label_count, kernel_initializer='he_normal', activation='softmax')(x)
 base_model = Model(inputs=input_tensor, outputs=x)

 labels = Input(name='the_labels', shape=[seq_len], dtype='float32')
 input_length = Input(name='input_length', shape=[1], dtype='int64')
 label_length = Input(name='label_length', shape=[1], dtype='int64')
 loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([x, labels, input_length, label_length])

 model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])
 model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
 model.summary()

 def test(base_model):
 file_list = []
 X, Y = gen_image_data(r'data\test', file_list)
 y_pred = base_model.predict(X)
 shape = y_pred[:, :, :].shape # 2:
 out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,
  :seq_len] # 2:
 print()
 error_count=0
 for i in range(len(X)):
  print(file_list[i])
  str_src = str(os.path.split(file_list[i])[-1]).split('.')[0].split('_')[-1]
  print(out[i])
  str_out = ''.join([str(x) for x in out[i] if x!=-1 ])
  print(str_src, str_out)
  if str_src!=str_out:
  error_count+=1
  print('################################',error_count)
  # img = cv2.imread(file_list[i])
  # cv2.imshow('image', img)
  # cv2.waitKey()

 class LossHistory(Callback):
 def on_train_begin(self, logs={}):
  self.losses = []

 def on_epoch_end(self, epoch, logs=None):
  model.save_weights('model_1018.w')
  base_model.save_weights('base_model_1018.w')
  test(base_model)

 def on_batch_end(self, batch, logs={}):
  self.losses.append(logs.get('loss'))


 # checkpointer = ModelCheckpoint(filepath="keras_seq2seq_1018.hdf5", verbose=1, save_best_only=True, )
 history = LossHistory()

 # base_model.load_weights('base_model_1018.w')
 # model.load_weights('model_1018.w')

 X,Y=gen_image_data()
 maxin=4900
 subseq_size = 100
 batch_size=10
 result=model.fit([X[:maxin], Y[:maxin], np.array(np.ones(len(X))*int(conv_shape[1]))[:maxin], np.array(np.ones(len(X))*seq_len)[:maxin]], Y[:maxin],
   batch_size=20,
   epochs=1000,
   callbacks=[history, plotter, EarlyStopping(patience=10)], #checkpointer, history,
   validation_data=([X[maxin:], Y[maxin:], np.array(np.ones(len(X))*int(conv_shape[1]))[maxin:], np.array(np.ones(len(X))*seq_len)[maxin:]], Y[maxin:]),
   )

 test(base_model)

 K.clear_session()

补充知识:日常填坑之keras.backend.ctc_batch_cost参数问题

InvalidArgumentError sequence_length(0) <=30错误

下面的代码是在网上绝大多数文章给出的关于k.ctc_batch_cost()函数的使用代码

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage: 
 y_pred = y_pred[:, 2:, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

可以注意到有一句:y_pred = y_pred[:, 2:, :],这里把y_pred 的第二维数据去掉了两列,说人话:把送进lstm序列的step减了2步。后来偶然在一篇文章中有提到说这里之所以减2是因为在将feature送入keras的lstm时自动少了2维,所以这里就写成这样了。估计是之前老版本的bug,现在的新版本已经修复了。如果依然按照上面的写法,会得到如下错误:

InvalidArgumentError sequence_length(0) <=30

'<='后面的数值 = 你cnn最后的输出维度 - 2。这个错误我找了很久,一直不明白30哪里来的,后来一行行的检查代码是发现了这里很可疑,于是改成如下形式错误解决。

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args 
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

训练时出现ctc_loss_calculator.cc:144] No valid path found或loss: inf错误

熟悉CTC算法的话,这个提示应该是ctc没找到有效路径。既然是没找到有效路径,那肯定是label和input之间哪个地方又出问题了!和input相关的错误已经解决了,那么肯定就是label的问题了。再看ctc_batch_cost的四个参数,labels和label_length这两个地方有可疑。对于ctc_batch_cost()的参数,labels需要one-hot编码,形状:[batch, max_labelLength],其中max_labelLength指预测的最大字符长度;label_length就是每个label中的字符长度了,受之前tf.ctc_loss的影响把这里都设置成了最大长度,所以报错。

对于参数labels而言,max_labelLength是能预测的最大字符长度。这个值与送lstm的featue的第二维,即特征序列的max_step有关,表面上看只要max_labelLength<max_step即可,但是如果小的不多依然会出现上述错误。至于到底要小多少,还得从ctc算法里找,由于ctc算法在标签中的每个字符后都加了一个空格,所以应该把这个长度考虑进去,所以有 max_labelLength < max_step//2。没仔细研究keras里ctc_batch_cost()函数的实现细节,上面是我的猜测。如果有很明确的答案,还请麻烦告诉我一声,谢了先!

错误代码:

batch_label_length = np.ones(batch_size) * max_labelLength

正确打开方式:

batch_x, batch_y = [], []
batch_input_length = np.ones(batch_size) * (max_img_weigth//8)
batch_label_length = []
for j in range(i, i + batch_size):
 x, y = self.get_img_data(index_all[j])
 batch_x.append(x)
 batch_y.append(y)
 batch_label_length.append(self.label_length[j])

最后附一张我的crnn的模型图:

使用keras框架cnn+ctc_loss识别不定长字符图片操作

以上这篇使用keras框架cnn+ctc_loss识别不定长字符图片操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中集合类型(set)学习小结
Jan 28 Python
python制作websocket服务器实例分享
Nov 20 Python
基于python select.select模块通信的实例讲解
Sep 21 Python
python list数据等间隔抽取并新建list存储的例子
Nov 27 Python
python pandas移动窗口函数rolling的用法
Feb 29 Python
Python3之乱码\xe6\x97\xa0\xe6\xb3\x95处理方式
May 11 Python
pyCharm 设置调试输出窗口中文显示方式(字符码转换)
Jun 09 Python
详解如何在PyCharm控制台中输出彩色文字和背景
Aug 17 Python
如何在scrapy中集成selenium爬取网页的方法
Nov 18 Python
python实现自动化群控的步骤
Apr 11 Python
Python实现制作销售数据可视化看板详解
Nov 27 Python
python实现学生信息管理系统(面向对象)
Jun 05 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
keras K.function获取某层的输出操作
Jun 29 #Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
You might like
php 字符转义 注意事项
2009/05/27 PHP
PHP中使用gettext来支持多语言的方法
2011/05/02 PHP
PHP的Yii框架的常用日志操作总结
2015/12/08 PHP
浅析php如何实现App常用的秒发功能
2016/08/03 PHP
PHP文件上传、客户端和服务器端加限制、抓取错误信息、完整步骤解析
2017/01/12 PHP
Laravel中9个不经常用的小技巧汇总
2019/04/16 PHP
jQuery 树形结构的选择器
2010/02/15 Javascript
基于Jquery的文字自动截取(提供源代码)
2011/08/09 Javascript
字符串的replace方法应用浅析
2011/12/06 Javascript
AngularJS入门知识之MVW类框架的编程思想探讨
2014/12/08 Javascript
yui3的AOP(面向切面编程)和OOP(面向对象编程)
2015/05/01 Javascript
NodeJS远程代码执行
2016/08/28 NodeJs
js实现密码强度检验
2017/01/15 Javascript
JS通过ajax + 多列布局 + 自动加载实现瀑布流效果
2019/05/30 Javascript
[53:13]2014 DOTA2国际邀请赛中国区预选赛5.21 DT VS LGD-GAMING
2014/05/22 DOTA
py中的目录与文件判别代码
2008/07/16 Python
python画出三角形外接圆和内切圆的方法
2018/01/25 Python
使用python读取csv文件快速插入数据库的实例
2018/06/21 Python
浅谈django rest jwt vue 跨域问题
2018/10/26 Python
详解Python传入参数的几种方法
2019/05/16 Python
python实现批量视频分帧、保存视频帧
2019/05/31 Python
Python 堆叠柱状图绘制方法
2019/07/29 Python
Python安装及Pycharm安装使用教程图解
2019/09/20 Python
移动web模拟客户端实现多方框输入密码效果【附代码】
2016/03/25 HTML / CSS
HTML5中语义化 b 和 i 标签
2008/10/17 HTML / CSS
html5 迷宫游戏(碰撞检测)实例一
2013/07/25 HTML / CSS
Dr. Martens马汀博士德国官网:马丁靴鼻祖
2019/12/26 全球购物
超市促销实习自我鉴定
2013/09/23 职场文书
小学生评语集锦
2014/04/18 职场文书
汽车服务工程专业自荐信
2014/09/02 职场文书
2014年大堂经理工作总结
2014/11/21 职场文书
小学运动会加油稿
2015/07/22 职场文书
农村房屋租赁合同(范本)
2019/07/23 职场文书
Python 可迭代对象 iterable的具体使用
2021/08/07 Python
Vue OpenLayer 为地图绘制风场效果
2022/04/24 Vue.js
app场景下uniapp的扫码记录
2022/07/23 Java/Android