pytorch制作自己的LMDB数据操作示例


Posted in Python onDecember 18, 2019

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
 if imageBin is None:
  return False
 imageBuf = np.fromstring(imageBin, dtype=np.uint8)
 img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
 imgH, imgW = img.shape[0], img.shape[1]
 if imgH * imgW == 0:
  return False
 return True
def writeCache(env, cache):
 with env.begin(write=True) as txn:
  for k, v in cache.items():
   txn.put(k.encode(), v)
def _is_difficult(word):
 assert isinstance(word, str)
 return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
 """
 Create LMDB dataset for CRNN training.
 ARGS:
   outputPath  : LMDB output path
   imagePathList : list of image path
   labelList   : list of corresponding groundtruth texts
   lexiconList  : (optional) list of lexicon lists
   checkValid  : if true, check the validity of every image
 """
 assert(len(imagePathList) == len(labelList))
 nSamples = len(imagePathList)
 env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB
 cache = {}
 cnt = 1
 for i in range(nSamples):
  imagePath = imagePathList[i]
  label = labelList[i]
  if len(label) == 0:
   continue
  if not os.path.exists(imagePath):
   print('%s does not exist' % imagePath)
   continue
  with open(imagePath, 'rb') as f:
   imageBin = f.read()
  if checkValid:
   if not checkImageIsValid(imageBin):
    print('%s is not a valid image' % imagePath)
    continue
  #数据库中都是二进制数据
  imageKey = 'image-%09d' % cnt#9位数不足填零
  labelKey = 'label-%09d' % cnt
  cache[imageKey] = imageBin
  cache[labelKey] = label.encode()
  if lexiconList:
   lexiconKey = 'lexicon-%09d' % cnt
   cache[lexiconKey] = ' '.join(lexiconList[i])
  if cnt % 1000 == 0:
   writeCache(env, cache)
   cache = {}
   print('Written %d / %d' % (cnt, nSamples))
  cnt += 1
 nSamples = cnt-1
 cache['num-samples'] = str(nSamples).encode()
 writeCache(env, cache)
 print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
  with open(txt_path,'r') as fr:
    jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
  txt_content_list=[]
  for jpg in jpg_list:
    label_path=jpg.replace('.jpg','.txt')
    with open(label_path,'r') as fr:
      try:
        str_tmp=fr.readline()
      except UnicodeDecodeError as e:
        print(label_path)
        raise(e)
      txt_content_list.append(str_tmp.strip())
  return jpg_list,txt_content_list
if __name__ == "__main__":
 txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
 lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
 imagePathList,labelList=get_sample_list(txt_path)
 createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__

from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
 import moxing as mox
 #moxing是一个分布式的框架 跳过
class LmdbDataset(data.Dataset):
 def __init__(self, root, voc_type, max_len, num_samples, transform=None):
  super(LmdbDataset, self).__init__()
  if global_args.run_on_remote:
   dataset_name = os.path.basename(root)
   data_cache_url = "/cache/%s" % dataset_name
   if not os.path.exists(data_cache_url):
    os.makedirs(data_cache_url)
   if mox.file.exists(root):
    mox.file.copy_parallel(root, data_cache_url)
   else:
    raise ValueError("%s not exists!" % root)
   self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
  else:
   self.env = lmdb.open(root, max_readers=32, readonly=True)
  assert self.env is not None, "cannot create lmdb from %s" % root
  self.txn = self.env.begin()
  self.voc_type = voc_type
  self.transform = transform
  self.max_len = max_len
  self.nSamples = int(self.txn.get(b"num-samples"))
  self.nSamples = min(self.nSamples, num_samples)
  assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
  self.EOS = 'EOS'
  self.PADDING = 'PADDING'
  self.UNKNOWN = 'UNKNOWN'
  self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
  self.char2id = dict(zip(self.voc, range(len(self.voc))))
  self.id2char = dict(zip(range(len(self.voc)), self.voc))
  self.rec_num_classes = len(self.voc)
  self.lowercase = (voc_type == 'LOWERCASE')
 def __len__(self):
  return self.nSamples
 def __getitem__(self, index):
  assert index <= len(self), 'index range error'
  index += 1
  img_key = b'image-%09d' % index
  imgbuf = self.txn.get(img_key)
  #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象
  buf = six.BytesIO()
  buf.write(imgbuf)
  buf.seek(0)
  try:
   img = Image.open(buf).convert('RGB')
   # img = Image.open(buf).convert('L')
   # img = img.convert('RGB')
  except IOError:
   print('Corrupted image for %d' % index)
   return self[index + 1]
  # reconition labels
  label_key = b'label-%09d' % index
  word = self.txn.get(label_key).decode()
  if self.lowercase:
   word = word.lower()
  ## fill with the padding token
  label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int)
  label_list = []
  for char in word:
   if char in self.char2id:
    label_list.append(self.char2id[char])
   else:
    ## add the unknown token
    print('{0} is out of vocabulary.'.format(char))
    label_list.append(self.char2id[self.UNKNOWN])
  ## add a stop token
  label_list = label_list + [self.char2id[self.EOS]]
  assert len(label_list) <= self.max_len
  label[:len(label_list)] = np.array(label_list)
  if len(label) <= 0:
   return self[index + 1]
  # label length
  label_len = len(label_list)
  if self.transform is not None:
   img = self.transform(img)
  return img, label, label_len

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python字符串格式化
Jun 15 Python
Python制作爬虫采集小说
Oct 25 Python
简单讲解Python中的字符串与字符串的输入输出
Mar 13 Python
浅谈pyhton学习中出现的各种问题(新手必看)
May 17 Python
Python实现的人工神经网络算法示例【基于反向传播算法】
Nov 11 Python
Python通过OpenCV的findContours获取轮廓并切割实例
Jan 05 Python
Python使用matplotlib实现绘制自定义图形功能示例
Jan 18 Python
python实战之实现excel读取、统计、写入的示例讲解
May 02 Python
opencv python 图像去噪的实现方法
Aug 31 Python
Python Opencv提取图片中某种颜色组成的图形的方法
Sep 19 Python
win10从零安装配置pytorch全过程图文详解
May 08 Python
python 命令行传参方法总结
May 25 Python
Python Gluon参数和模块命名操作教程
Dec 18 #Python
python turtle 绘制太极图的实例
Dec 18 #Python
Python使用gluon/mxnet模块实现的mnist手写数字识别功能完整示例
Dec 18 #Python
简单了解Python读取大文件代码实例
Dec 18 #Python
python 比较2张图片的相似度的方法示例
Dec 18 #Python
使用Python的Turtle库绘制森林的实例
Dec 18 #Python
python3 requests库实现多图片爬取教程
Dec 18 #Python
You might like
将数字格式的计算结果转为汉字格式
2006/10/09 PHP
php设计模式 Chain Of Responsibility (职责链模式)
2011/06/26 PHP
PHPMYADMIN导入数据最大为2M的解决方法
2012/04/23 PHP
将时间以距今多久的形式表示,PHP,js双版本
2012/09/25 PHP
PHP实现提取一个图像文件并在浏览器上显示的代码
2012/10/06 PHP
php使用Cookie控制访问授权的方法
2015/01/21 PHP
PHP简单实现冒泡排序的方法
2016/12/26 PHP
PHP关键特性之命名空间实例详解
2017/05/06 PHP
基于 Swoole 的微信扫码登录功能实现代码
2018/01/15 PHP
ThinkPHP3.2.3框架Memcache缓存使用方法实例总结
2019/04/15 PHP
基于jQuery的Spin Button自定义文本框数值自增或自减
2010/07/17 Javascript
基于Jquery的实现回车键Enter切换焦点
2010/09/14 Javascript
Jquery easyUI 更新行示例
2014/03/06 Javascript
JavaScript实现穷举排列(permutation)算法谜题解答
2014/12/29 Javascript
JavaScript学习笔记之内置对象
2015/01/22 Javascript
使用jQueryMobile实现滑动翻页效果的方法
2015/02/04 Javascript
node-webkit打包成exe文件被360误报木马的解决方法
2015/03/11 Javascript
JavaScript实现级联菜单的方法
2015/06/29 Javascript
深入解读JavaScript中的Iterator和for-of循环
2015/07/28 Javascript
jQuery 3.0 的 setter和getter 模式详解
2016/07/11 Javascript
用jquery快速解决IE输入框不能输入的问题
2016/10/04 Javascript
D3.js进阶系列之CSV表格文件的读取详解
2017/06/06 Javascript
小程序数据通信方法大全(推荐)
2019/04/15 Javascript
vue watch关于对象内的属性监听
2019/04/22 Javascript
vue.js实现数据库的JSON数据输出渲染到html页面功能示例
2019/08/03 Javascript
浅谈tensorflow1.0 池化层(pooling)和全连接层(dense)
2018/04/27 Python
Python正则表达式指南 推荐
2018/10/09 Python
Python实现的银行系统模拟程序完整案例
2019/04/12 Python
Pycharm生成可执行文件.exe的实现方法
2020/06/02 Python
Keras设置以及获取权重的实现
2020/06/19 Python
自荐信写法介绍
2014/01/25 职场文书
《世界多美呀》教学反思
2016/02/22 职场文书
2019垃圾分类宣传口号汇总
2019/08/16 职场文书
PyTorch的Debug指南
2021/05/07 Python
Elasticsearch 数据类型及管理
2022/04/19 Python
科学家研发出新型速效酶,可在 24 小时内降解塑料制品
2022/04/29 数码科技