pytorch制作自己的LMDB数据操作示例


Posted in Python onDecember 18, 2019

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
 if imageBin is None:
  return False
 imageBuf = np.fromstring(imageBin, dtype=np.uint8)
 img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
 imgH, imgW = img.shape[0], img.shape[1]
 if imgH * imgW == 0:
  return False
 return True
def writeCache(env, cache):
 with env.begin(write=True) as txn:
  for k, v in cache.items():
   txn.put(k.encode(), v)
def _is_difficult(word):
 assert isinstance(word, str)
 return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
 """
 Create LMDB dataset for CRNN training.
 ARGS:
   outputPath  : LMDB output path
   imagePathList : list of image path
   labelList   : list of corresponding groundtruth texts
   lexiconList  : (optional) list of lexicon lists
   checkValid  : if true, check the validity of every image
 """
 assert(len(imagePathList) == len(labelList))
 nSamples = len(imagePathList)
 env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB
 cache = {}
 cnt = 1
 for i in range(nSamples):
  imagePath = imagePathList[i]
  label = labelList[i]
  if len(label) == 0:
   continue
  if not os.path.exists(imagePath):
   print('%s does not exist' % imagePath)
   continue
  with open(imagePath, 'rb') as f:
   imageBin = f.read()
  if checkValid:
   if not checkImageIsValid(imageBin):
    print('%s is not a valid image' % imagePath)
    continue
  #数据库中都是二进制数据
  imageKey = 'image-%09d' % cnt#9位数不足填零
  labelKey = 'label-%09d' % cnt
  cache[imageKey] = imageBin
  cache[labelKey] = label.encode()
  if lexiconList:
   lexiconKey = 'lexicon-%09d' % cnt
   cache[lexiconKey] = ' '.join(lexiconList[i])
  if cnt % 1000 == 0:
   writeCache(env, cache)
   cache = {}
   print('Written %d / %d' % (cnt, nSamples))
  cnt += 1
 nSamples = cnt-1
 cache['num-samples'] = str(nSamples).encode()
 writeCache(env, cache)
 print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
  with open(txt_path,'r') as fr:
    jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
  txt_content_list=[]
  for jpg in jpg_list:
    label_path=jpg.replace('.jpg','.txt')
    with open(label_path,'r') as fr:
      try:
        str_tmp=fr.readline()
      except UnicodeDecodeError as e:
        print(label_path)
        raise(e)
      txt_content_list.append(str_tmp.strip())
  return jpg_list,txt_content_list
if __name__ == "__main__":
 txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
 lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
 imagePathList,labelList=get_sample_list(txt_path)
 createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__

from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
 import moxing as mox
 #moxing是一个分布式的框架 跳过
class LmdbDataset(data.Dataset):
 def __init__(self, root, voc_type, max_len, num_samples, transform=None):
  super(LmdbDataset, self).__init__()
  if global_args.run_on_remote:
   dataset_name = os.path.basename(root)
   data_cache_url = "/cache/%s" % dataset_name
   if not os.path.exists(data_cache_url):
    os.makedirs(data_cache_url)
   if mox.file.exists(root):
    mox.file.copy_parallel(root, data_cache_url)
   else:
    raise ValueError("%s not exists!" % root)
   self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
  else:
   self.env = lmdb.open(root, max_readers=32, readonly=True)
  assert self.env is not None, "cannot create lmdb from %s" % root
  self.txn = self.env.begin()
  self.voc_type = voc_type
  self.transform = transform
  self.max_len = max_len
  self.nSamples = int(self.txn.get(b"num-samples"))
  self.nSamples = min(self.nSamples, num_samples)
  assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
  self.EOS = 'EOS'
  self.PADDING = 'PADDING'
  self.UNKNOWN = 'UNKNOWN'
  self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
  self.char2id = dict(zip(self.voc, range(len(self.voc))))
  self.id2char = dict(zip(range(len(self.voc)), self.voc))
  self.rec_num_classes = len(self.voc)
  self.lowercase = (voc_type == 'LOWERCASE')
 def __len__(self):
  return self.nSamples
 def __getitem__(self, index):
  assert index <= len(self), 'index range error'
  index += 1
  img_key = b'image-%09d' % index
  imgbuf = self.txn.get(img_key)
  #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象
  buf = six.BytesIO()
  buf.write(imgbuf)
  buf.seek(0)
  try:
   img = Image.open(buf).convert('RGB')
   # img = Image.open(buf).convert('L')
   # img = img.convert('RGB')
  except IOError:
   print('Corrupted image for %d' % index)
   return self[index + 1]
  # reconition labels
  label_key = b'label-%09d' % index
  word = self.txn.get(label_key).decode()
  if self.lowercase:
   word = word.lower()
  ## fill with the padding token
  label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int)
  label_list = []
  for char in word:
   if char in self.char2id:
    label_list.append(self.char2id[char])
   else:
    ## add the unknown token
    print('{0} is out of vocabulary.'.format(char))
    label_list.append(self.char2id[self.UNKNOWN])
  ## add a stop token
  label_list = label_list + [self.char2id[self.EOS]]
  assert len(label_list) <= self.max_len
  label[:len(label_list)] = np.array(label_list)
  if len(label) <= 0:
   return self[index + 1]
  # label length
  label_len = len(label_list)
  if self.transform is not None:
   img = self.transform(img)
  return img, label, label_len

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
在python的WEB框架Flask中使用多个配置文件的解决方法
Apr 18 Python
Python接收Gmail新邮件并发送到gtalk的方法
Mar 10 Python
Python中列表的一些基本操作知识汇总
May 20 Python
Python数据结构与算法之列表(链表,linked list)简单实现
Oct 30 Python
Python3中的列表,元组,字典,字符串相关知识小结
Nov 10 Python
python3读取csv和xlsx文件的实例
Jun 22 Python
浅谈python3.6的tkinter运行问题
Feb 22 Python
聊聊python里如何用Borg pattern实现的单例模式
Jun 06 Python
Django Admin中增加导出CSV功能过程解析
Sep 04 Python
通过实例解析Python RPC实现原理及方法
Jul 07 Python
详解如何使用Pytest进行自动化测试
Jan 14 Python
python爬虫用request库处理cookie的实例讲解
Feb 20 Python
Python Gluon参数和模块命名操作教程
Dec 18 #Python
python turtle 绘制太极图的实例
Dec 18 #Python
Python使用gluon/mxnet模块实现的mnist手写数字识别功能完整示例
Dec 18 #Python
简单了解Python读取大文件代码实例
Dec 18 #Python
python 比较2张图片的相似度的方法示例
Dec 18 #Python
使用Python的Turtle库绘制森林的实例
Dec 18 #Python
python3 requests库实现多图片爬取教程
Dec 18 #Python
You might like
PHP实现的XML操作类【XML Library】
2016/12/29 PHP
javascript iframe中打开文件,并检测iframe存在否
2008/12/28 Javascript
javascript打开新窗口同时关闭旧窗口
2009/01/16 Javascript
js动态添加表格数据使用insertRow和insertCell实现
2014/05/22 Javascript
JavaScript中跨域调用Flash的方法
2014/08/11 Javascript
js怎么判断flash swf文件是否加载完毕
2014/08/14 Javascript
Node.js静态文件服务器改进版
2016/01/10 Javascript
javascript瀑布流式图片懒加载实例解析与优化
2016/02/23 Javascript
jQuery绑定事件on()与弹窗的简要概述
2016/04/27 Javascript
AngularJS基于provider实现全局变量的读取和赋值方法
2017/06/28 Javascript
JS实现页面内跳转的简单代码
2017/09/03 Javascript
javascript 中模板方法单例的实现方法
2017/10/17 Javascript
浅谈ES6 模板字符串的具体使用方法
2017/11/07 Javascript
JavaScript检测是否开启了控制台(F12调试工具)
2020/10/02 Javascript
Vue实现多页签组件
2021/01/14 Vue.js
Python ljust rjust center输出
2008/09/06 Python
Python translator使用实例
2008/09/06 Python
Python实现返回数组中第i小元素的方法示例
2017/12/04 Python
用Python写脚本,实现完全备份和增量备份的示例
2018/04/29 Python
详解PANDAS 数据合并与重塑(join/merge篇)
2019/07/09 Python
python 实现手机自动拨打电话的方法(通话压力测试)
2019/08/08 Python
python3实现的zip格式压缩文件夹操作示例
2019/08/17 Python
Django实现内容缓存实例方法
2020/06/30 Python
opencv 图像滤波(均值,方框,高斯,中值)
2020/07/08 Python
Python配置pip国内镜像源的实现
2020/08/20 Python
关于python中导入文件到list的问题
2020/10/31 Python
matplotlib更改窗口图标的方法示例
2021/02/03 Python
施华洛世奇水晶荷兰官方网站:SWAROVSKI荷兰
2017/05/12 全球购物
运动会闭幕式解说词
2014/02/21 职场文书
优秀经理获奖感言
2014/03/04 职场文书
求职者怎样写自荐信
2014/04/13 职场文书
企业公益活动策划方案
2014/08/24 职场文书
民族学专业职业生涯规划范文:积跬步以至千里
2014/09/11 职场文书
创业计划书之熟食店
2019/10/16 职场文书
Python包管理工具pip的15 个使用小技巧
2021/05/17 Python
收音机爱好者玩机13年,简评其使用过的19台收音机
2022/04/30 无线电