使用keras框架cnn+ctc_loss识别不定长字符图片操作


Posted in Python onJune 29, 2020

我就废话不多说了,大家还是直接看代码吧~

# -*- coding: utf-8 -*-
#keras==2.0.5
#tensorflow==1.1.0

import os,sys,string
import sys
import logging
import multiprocessing
import time
import json
import cv2
import numpy as np
from sklearn.model_selection import train_test_split

import keras
import keras.backend as K
from keras.datasets import mnist
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import *
from keras import backend as K
# from keras.utils.visualize_util import plot
from visual_callbacks import AccLossPlotter
plotter = AccLossPlotter(graphs=['acc', 'loss'], save_graph=True, save_graph_path=sys.path[0])

#识别字符集
char_ocr='0123456789' #string.digits
#定义识别字符串的最大长度
seq_len=8
#识别结果集合个数 0-9
label_count=len(char_ocr)+1

def get_label(filepath):
 # print(str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1])
 lab=[]
 for num in str(os.path.split(filepath)[-1]).split('.')[0].split('_')[-1]:
 lab.append(int(char_ocr.find(num)))
 if len(lab) < seq_len:
 cur_seq_len = len(lab)
 for i in range(seq_len - cur_seq_len):
  lab.append(label_count) #
 return lab

def gen_image_data(dir=r'data\train', file_list=[]):
 dir_path = dir
 for rt, dirs, files in os.walk(dir_path): # =pathDir
 for filename in files:
  # print (filename)
  if filename.find('.') >= 0:
  (shotname, extension) = os.path.splitext(filename)
  # print shotname,extension
  if extension == '.tif': # extension == '.png' or
   file_list.append(os.path.join('%s\\%s' % (rt, filename)))
   # print (filename)

 print(len(file_list))
 index = 0
 X = []
 Y = []
 for file in file_list:

 index += 1
 # if index>1000:
 # break
 # print(file)
 img = cv2.imread(file, 0)
 # print(np.shape(img))
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 img = cv2.resize(img, (150, 50), interpolation=cv2.INTER_CUBIC)
 img = cv2.transpose(img,(50,150))
 img =cv2.flip(img,1)
 # cv2.namedWindow("the window")
 # cv2.imshow("the window",img)
 # cv2.waitKey()
 img = (255 - img) / 256 # 反色处理
 X.append([img])
 Y.append(get_label(file))
 # print(get_label(file))
 # print(np.shape(X))
 # print(np.shape(X))

 # print(np.shape(X))
 X = np.transpose(X, (0, 2, 3, 1))
 X = np.array(X)
 Y = np.array(Y)
 return X,Y

# the actual loss calc occurs here despite it not being
# an internal Keras loss function

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage:
 # y_pred = y_pred[:, 2:, :] 测试感觉没影响
 y_pred = y_pred[:, :, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

if __name__ == '__main__':
 height=150
 width=50
 input_tensor = Input((height, width, 1))
 x = input_tensor
 for i in range(3):
 x = Convolution2D(32*2**i, (3, 3), activation='relu', padding='same')(x)
 # x = Convolution2D(32*2**i, (3, 3), activation='relu')(x)
 x = MaxPooling2D(pool_size=(2, 2))(x)

 conv_shape = x.get_shape()
 # print(conv_shape)
 x = Reshape(target_shape=(int(conv_shape[1]), int(conv_shape[2] * conv_shape[3])))(x)

 x = Dense(32, activation='relu')(x)

 gru_1 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru1')(x)
 gru_1b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru1_b')(x)
 gru1_merged = add([gru_1, gru_1b]) ###################

 gru_2 = GRU(32, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
 gru_2b = GRU(32, return_sequences=True, go_backwards=True, kernel_initializer='he_normal', name='gru2_b')(
 gru1_merged)
 x = concatenate([gru_2, gru_2b]) ######################
 x = Dropout(0.25)(x)
 x = Dense(label_count, kernel_initializer='he_normal', activation='softmax')(x)
 base_model = Model(inputs=input_tensor, outputs=x)

 labels = Input(name='the_labels', shape=[seq_len], dtype='float32')
 input_length = Input(name='input_length', shape=[1], dtype='int64')
 label_length = Input(name='label_length', shape=[1], dtype='int64')
 loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([x, labels, input_length, label_length])

 model = Model(inputs=[input_tensor, labels, input_length, label_length], outputs=[loss_out])
 model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')
 model.summary()

 def test(base_model):
 file_list = []
 X, Y = gen_image_data(r'data\test', file_list)
 y_pred = base_model.predict(X)
 shape = y_pred[:, :, :].shape # 2:
 out = K.get_value(K.ctc_decode(y_pred[:, :, :], input_length=np.ones(shape[0]) * shape[1])[0][0])[:,
  :seq_len] # 2:
 print()
 error_count=0
 for i in range(len(X)):
  print(file_list[i])
  str_src = str(os.path.split(file_list[i])[-1]).split('.')[0].split('_')[-1]
  print(out[i])
  str_out = ''.join([str(x) for x in out[i] if x!=-1 ])
  print(str_src, str_out)
  if str_src!=str_out:
  error_count+=1
  print('################################',error_count)
  # img = cv2.imread(file_list[i])
  # cv2.imshow('image', img)
  # cv2.waitKey()

 class LossHistory(Callback):
 def on_train_begin(self, logs={}):
  self.losses = []

 def on_epoch_end(self, epoch, logs=None):
  model.save_weights('model_1018.w')
  base_model.save_weights('base_model_1018.w')
  test(base_model)

 def on_batch_end(self, batch, logs={}):
  self.losses.append(logs.get('loss'))


 # checkpointer = ModelCheckpoint(filepath="keras_seq2seq_1018.hdf5", verbose=1, save_best_only=True, )
 history = LossHistory()

 # base_model.load_weights('base_model_1018.w')
 # model.load_weights('model_1018.w')

 X,Y=gen_image_data()
 maxin=4900
 subseq_size = 100
 batch_size=10
 result=model.fit([X[:maxin], Y[:maxin], np.array(np.ones(len(X))*int(conv_shape[1]))[:maxin], np.array(np.ones(len(X))*seq_len)[:maxin]], Y[:maxin],
   batch_size=20,
   epochs=1000,
   callbacks=[history, plotter, EarlyStopping(patience=10)], #checkpointer, history,
   validation_data=([X[maxin:], Y[maxin:], np.array(np.ones(len(X))*int(conv_shape[1]))[maxin:], np.array(np.ones(len(X))*seq_len)[maxin:]], Y[maxin:]),
   )

 test(base_model)

 K.clear_session()

补充知识:日常填坑之keras.backend.ctc_batch_cost参数问题

InvalidArgumentError sequence_length(0) <=30错误

下面的代码是在网上绝大多数文章给出的关于k.ctc_batch_cost()函数的使用代码

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args
 # the 2 is critical here since the first couple outputs of the RNN
 # tend to be garbage: 
 y_pred = y_pred[:, 2:, :]
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

可以注意到有一句:y_pred = y_pred[:, 2:, :],这里把y_pred 的第二维数据去掉了两列,说人话:把送进lstm序列的step减了2步。后来偶然在一篇文章中有提到说这里之所以减2是因为在将feature送入keras的lstm时自动少了2维,所以这里就写成这样了。估计是之前老版本的bug,现在的新版本已经修复了。如果依然按照上面的写法,会得到如下错误:

InvalidArgumentError sequence_length(0) <=30

'<='后面的数值 = 你cnn最后的输出维度 - 2。这个错误我找了很久,一直不明白30哪里来的,后来一行行的检查代码是发现了这里很可疑,于是改成如下形式错误解决。

def ctc_lambda_func(args):
 y_pred, labels, input_length, label_length = args 
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

训练时出现ctc_loss_calculator.cc:144] No valid path found或loss: inf错误

熟悉CTC算法的话,这个提示应该是ctc没找到有效路径。既然是没找到有效路径,那肯定是label和input之间哪个地方又出问题了!和input相关的错误已经解决了,那么肯定就是label的问题了。再看ctc_batch_cost的四个参数,labels和label_length这两个地方有可疑。对于ctc_batch_cost()的参数,labels需要one-hot编码,形状:[batch, max_labelLength],其中max_labelLength指预测的最大字符长度;label_length就是每个label中的字符长度了,受之前tf.ctc_loss的影响把这里都设置成了最大长度,所以报错。

对于参数labels而言,max_labelLength是能预测的最大字符长度。这个值与送lstm的featue的第二维,即特征序列的max_step有关,表面上看只要max_labelLength<max_step即可,但是如果小的不多依然会出现上述错误。至于到底要小多少,还得从ctc算法里找,由于ctc算法在标签中的每个字符后都加了一个空格,所以应该把这个长度考虑进去,所以有 max_labelLength < max_step//2。没仔细研究keras里ctc_batch_cost()函数的实现细节,上面是我的猜测。如果有很明确的答案,还请麻烦告诉我一声,谢了先!

错误代码:

batch_label_length = np.ones(batch_size) * max_labelLength

正确打开方式:

batch_x, batch_y = [], []
batch_input_length = np.ones(batch_size) * (max_img_weigth//8)
batch_label_length = []
for j in range(i, i + batch_size):
 x, y = self.get_img_data(index_all[j])
 batch_x.append(x)
 batch_y.append(y)
 batch_label_length.append(self.label_length[j])

最后附一张我的crnn的模型图:

使用keras框架cnn+ctc_loss识别不定长字符图片操作

以上这篇使用keras框架cnn+ctc_loss识别不定长字符图片操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例介绍Python中的25个隐藏特性
Mar 30 Python
使用python生成目录树
Mar 29 Python
神经网络(BP)算法Python实现及应用
Apr 16 Python
Flask框架信号用法实例分析
Jul 24 Python
解决python中无法自动补全代码的问题
Dec 04 Python
python reverse反转部分数组的实例
Dec 13 Python
解决pycharm每次新建项目都要重新安装一些第三方库的问题
Jan 17 Python
pytorch-神经网络拟合曲线实例
Jan 15 Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 Python
Python JSON常用编解码方法代码实例
Sep 05 Python
Python APScheduler执行使用方法详解
Dec 10 Python
自动在Windows中运行Python脚本并定时触发功能实现
Sep 04 Python
浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
Jun 29 #Python
如何使用python记录室友的抖音在线时间
Jun 29 #Python
Python sublime安装及配置过程详解
Jun 29 #Python
keras K.function获取某层的输出操作
Jun 29 #Python
Python pytesseract验证码识别库用法解析
Jun 29 #Python
用Python开发app后端有优势吗
Jun 29 #Python
在keras里实现自定义上采样层
Jun 28 #Python
You might like
php 全局变量范围分析
2009/08/07 PHP
使用HMAC-SHA1签名方法详解
2013/06/26 PHP
php操作redis中的hash和zset类型数据的方法和代码例子
2014/07/05 PHP
PHP中使用asort进行中文排序失效的问题处理
2014/08/18 PHP
一款简单实用的php操作mysql数据库类
2014/12/08 PHP
smarty模板引擎中内建函数if、elseif和else的使用方法
2015/01/22 PHP
使用Thinkphp框架开发移动端接口
2015/08/05 PHP
PhpStorm terminal无法输入命令的解决方法
2016/10/09 PHP
在jquery中处理带有命名空间的XML数据
2011/06/13 Javascript
jQuery图片滚动图片的效果(另类实现)
2013/06/02 Javascript
jQuery截取指定长度字符串代码
2014/08/21 Javascript
jquery中checkbox全选失效的解决方法
2014/12/26 Javascript
Jquery Ajax Error 调试错误的技巧
2015/11/20 Javascript
js简单网速测试方法完整实例
2015/12/15 Javascript
在JavaScript中call()与apply()区别
2016/01/22 Javascript
Node.js Addons翻译(C/C++扩展)
2016/06/12 Javascript
JS动态添加的div点击跳转到另一页面实现代码
2017/09/30 Javascript
基于Proxy的小程序状态管理实现
2019/06/14 Javascript
jQuery实现弹幕特效
2019/11/29 jQuery
JavaScript组合模式---引入案例分析
2020/05/23 Javascript
Python实现的堆排序算法原理与用法实例分析
2017/11/22 Python
Python使用requests及BeautifulSoup构建爬虫实例代码
2018/01/24 Python
python监控文件并且发送告警邮件
2018/06/21 Python
用python统计代码行的示例(包括空行和注释)
2018/07/24 Python
Django unittest 设置跳过某些case的方法
2018/12/26 Python
Python全局锁中如何合理运用多线程(多进程)
2019/11/06 Python
Python实现进度条和时间预估的示例代码
2020/06/02 Python
解决python 执行sql语句时所传参数含有单引号的问题
2020/06/06 Python
Debenhams爱尔兰:英国知名的百货公司
2017/01/02 全球购物
Parts Express:音频、视频和扬声器的第一来源
2017/04/25 全球购物
初中升旗仪式演讲稿
2014/05/08 职场文书
副校长竞聘演讲稿
2014/09/01 职场文书
写给医生的感谢信
2015/01/22 职场文书
element多个表单校验的实现
2021/05/27 Javascript
关于使用Redisson订阅数问题
2022/01/18 Redis
小喇叭开始广播了! 四十多年前珍贵老照片
2022/05/09 无线电