pytorch制作自己的LMDB数据操作示例


Posted in Python onDecember 18, 2019

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
 if imageBin is None:
  return False
 imageBuf = np.fromstring(imageBin, dtype=np.uint8)
 img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
 imgH, imgW = img.shape[0], img.shape[1]
 if imgH * imgW == 0:
  return False
 return True
def writeCache(env, cache):
 with env.begin(write=True) as txn:
  for k, v in cache.items():
   txn.put(k.encode(), v)
def _is_difficult(word):
 assert isinstance(word, str)
 return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
 """
 Create LMDB dataset for CRNN training.
 ARGS:
   outputPath  : LMDB output path
   imagePathList : list of image path
   labelList   : list of corresponding groundtruth texts
   lexiconList  : (optional) list of lexicon lists
   checkValid  : if true, check the validity of every image
 """
 assert(len(imagePathList) == len(labelList))
 nSamples = len(imagePathList)
 env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB
 cache = {}
 cnt = 1
 for i in range(nSamples):
  imagePath = imagePathList[i]
  label = labelList[i]
  if len(label) == 0:
   continue
  if not os.path.exists(imagePath):
   print('%s does not exist' % imagePath)
   continue
  with open(imagePath, 'rb') as f:
   imageBin = f.read()
  if checkValid:
   if not checkImageIsValid(imageBin):
    print('%s is not a valid image' % imagePath)
    continue
  #数据库中都是二进制数据
  imageKey = 'image-%09d' % cnt#9位数不足填零
  labelKey = 'label-%09d' % cnt
  cache[imageKey] = imageBin
  cache[labelKey] = label.encode()
  if lexiconList:
   lexiconKey = 'lexicon-%09d' % cnt
   cache[lexiconKey] = ' '.join(lexiconList[i])
  if cnt % 1000 == 0:
   writeCache(env, cache)
   cache = {}
   print('Written %d / %d' % (cnt, nSamples))
  cnt += 1
 nSamples = cnt-1
 cache['num-samples'] = str(nSamples).encode()
 writeCache(env, cache)
 print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
  with open(txt_path,'r') as fr:
    jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
  txt_content_list=[]
  for jpg in jpg_list:
    label_path=jpg.replace('.jpg','.txt')
    with open(label_path,'r') as fr:
      try:
        str_tmp=fr.readline()
      except UnicodeDecodeError as e:
        print(label_path)
        raise(e)
      txt_content_list.append(str_tmp.strip())
  return jpg_list,txt_content_list
if __name__ == "__main__":
 txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
 lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
 imagePathList,labelList=get_sample_list(txt_path)
 createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__

from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
 import moxing as mox
 #moxing是一个分布式的框架 跳过
class LmdbDataset(data.Dataset):
 def __init__(self, root, voc_type, max_len, num_samples, transform=None):
  super(LmdbDataset, self).__init__()
  if global_args.run_on_remote:
   dataset_name = os.path.basename(root)
   data_cache_url = "/cache/%s" % dataset_name
   if not os.path.exists(data_cache_url):
    os.makedirs(data_cache_url)
   if mox.file.exists(root):
    mox.file.copy_parallel(root, data_cache_url)
   else:
    raise ValueError("%s not exists!" % root)
   self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
  else:
   self.env = lmdb.open(root, max_readers=32, readonly=True)
  assert self.env is not None, "cannot create lmdb from %s" % root
  self.txn = self.env.begin()
  self.voc_type = voc_type
  self.transform = transform
  self.max_len = max_len
  self.nSamples = int(self.txn.get(b"num-samples"))
  self.nSamples = min(self.nSamples, num_samples)
  assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
  self.EOS = 'EOS'
  self.PADDING = 'PADDING'
  self.UNKNOWN = 'UNKNOWN'
  self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
  self.char2id = dict(zip(self.voc, range(len(self.voc))))
  self.id2char = dict(zip(range(len(self.voc)), self.voc))
  self.rec_num_classes = len(self.voc)
  self.lowercase = (voc_type == 'LOWERCASE')
 def __len__(self):
  return self.nSamples
 def __getitem__(self, index):
  assert index <= len(self), 'index range error'
  index += 1
  img_key = b'image-%09d' % index
  imgbuf = self.txn.get(img_key)
  #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象
  buf = six.BytesIO()
  buf.write(imgbuf)
  buf.seek(0)
  try:
   img = Image.open(buf).convert('RGB')
   # img = Image.open(buf).convert('L')
   # img = img.convert('RGB')
  except IOError:
   print('Corrupted image for %d' % index)
   return self[index + 1]
  # reconition labels
  label_key = b'label-%09d' % index
  word = self.txn.get(label_key).decode()
  if self.lowercase:
   word = word.lower()
  ## fill with the padding token
  label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int)
  label_list = []
  for char in word:
   if char in self.char2id:
    label_list.append(self.char2id[char])
   else:
    ## add the unknown token
    print('{0} is out of vocabulary.'.format(char))
    label_list.append(self.char2id[self.UNKNOWN])
  ## add a stop token
  label_list = label_list + [self.char2id[self.EOS]]
  assert len(label_list) <= self.max_len
  label[:len(label_list)] = np.array(label_list)
  if len(label) <= 0:
   return self[index + 1]
  # label length
  label_len = len(label_list)
  if self.transform is not None:
   img = self.transform(img)
  return img, label, label_len

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python Web开发模板引擎优缺点总结
May 06 Python
Python是编译运行的验证方法
Jan 30 Python
python搜索指定目录的方法
Apr 29 Python
PyQt 线程类 QThread使用详解
Jul 16 Python
Python人工智能之路 之PyAudio 实现录音 自动化交互实现问答
Aug 13 Python
pytorch 常用线性函数详解
Jan 15 Python
TensorFlow2.1.0最新版本安装详细教程
Apr 08 Python
python线程池如何使用
May 28 Python
python爬取”顶点小说网“《纯阳剑尊》的示例代码
Oct 16 Python
python链表类中获取元素实例方法
Feb 23 Python
Jupyter Notebook内使用argparse报错的解决方案
Jun 03 Python
python人工智能human learn绘图可创建机器学习模型
Nov 23 Python
Python Gluon参数和模块命名操作教程
Dec 18 #Python
python turtle 绘制太极图的实例
Dec 18 #Python
Python使用gluon/mxnet模块实现的mnist手写数字识别功能完整示例
Dec 18 #Python
简单了解Python读取大文件代码实例
Dec 18 #Python
python 比较2张图片的相似度的方法示例
Dec 18 #Python
使用Python的Turtle库绘制森林的实例
Dec 18 #Python
python3 requests库实现多图片爬取教程
Dec 18 #Python
You might like
建立动态的WML站点(一)
2006/10/09 PHP
兼容性最强的PHP生成缩略图的函数代码(修改版)
2011/01/18 PHP
php ci框架验证码实例分析
2013/06/26 PHP
PHP中使用localhost连接Mysql不成功的解决方法
2014/08/20 PHP
MyEclipse常用配置图文教程
2014/09/11 PHP
phpstorm编辑器乱码问题解决
2014/12/01 PHP
PHP正则表达式过滤html标签属性(DEMO)
2016/05/04 PHP
一个简单的JS鼠标悬停特效具体方法
2013/06/17 Javascript
jQuery.position()方法获取不到值的安全替换方法
2015/03/13 Javascript
js带点自动图片轮播幻灯片特效代码分享
2015/09/07 Javascript
JS实现的竖向折叠菜单代码
2015/10/21 Javascript
JS鼠标拖拽实例分析
2015/11/23 Javascript
基于Javascript实现弹出页面效果
2016/01/01 Javascript
JavaScript知识点总结(五)之Javascript中两个等于号(==)和三个等于号(===)的区别
2016/05/31 Javascript
JS图片放大效果简单实现代码
2016/09/08 Javascript
JavaScript构建自己的对象示例
2016/11/29 Javascript
jQuery实现简易的输入框字数计数功能示例
2017/01/16 Javascript
vue+vuecli+webpack中使用mockjs模拟后端数据的示例
2017/10/24 Javascript
mui框架 页面无法滚动的解决方法(推荐)
2018/01/25 Javascript
微信小程序五子棋游戏的悔棋实现方法【附demo源码下载】
2019/02/20 Javascript
JS简单数组排序操作示例【sort方法】
2019/05/17 Javascript
Vue的click事件防抖和节流处理详解
2019/11/13 Javascript
Python利用operator模块实现对象的多级排序详解
2017/05/09 Python
Python 爬虫图片简单实现
2017/06/01 Python
Python 获取当前所在目录的方法详解
2017/08/02 Python
python画出三角形外接圆和内切圆的方法
2018/01/25 Python
numpy实现合并多维矩阵、list的扩展方法
2018/05/08 Python
python获取微信企业号打卡数据并生成windows计划任务
2019/04/30 Python
Python实现快速大文件比较代码解析
2020/09/04 Python
Django Admin后台模型列表页面如何添加自定义操作按钮
2020/11/11 Python
深圳-东方伟业笔试部分
2015/02/11 面试题
《乡愁》教学反思
2014/02/18 职场文书
活动总结怎么写啊
2014/05/07 职场文书
2016年教师党员创先争优承诺书
2016/03/24 职场文书
MySQL创建高性能索引的全步骤
2021/05/02 MySQL
Spring Boot mybatis-config 和 log4j 输出sql 日志的方式
2021/07/26 Java/Android