pytorch制作自己的LMDB数据操作示例


Posted in Python onDecember 18, 2019

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
 if imageBin is None:
  return False
 imageBuf = np.fromstring(imageBin, dtype=np.uint8)
 img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
 imgH, imgW = img.shape[0], img.shape[1]
 if imgH * imgW == 0:
  return False
 return True
def writeCache(env, cache):
 with env.begin(write=True) as txn:
  for k, v in cache.items():
   txn.put(k.encode(), v)
def _is_difficult(word):
 assert isinstance(word, str)
 return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
 """
 Create LMDB dataset for CRNN training.
 ARGS:
   outputPath  : LMDB output path
   imagePathList : list of image path
   labelList   : list of corresponding groundtruth texts
   lexiconList  : (optional) list of lexicon lists
   checkValid  : if true, check the validity of every image
 """
 assert(len(imagePathList) == len(labelList))
 nSamples = len(imagePathList)
 env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB
 cache = {}
 cnt = 1
 for i in range(nSamples):
  imagePath = imagePathList[i]
  label = labelList[i]
  if len(label) == 0:
   continue
  if not os.path.exists(imagePath):
   print('%s does not exist' % imagePath)
   continue
  with open(imagePath, 'rb') as f:
   imageBin = f.read()
  if checkValid:
   if not checkImageIsValid(imageBin):
    print('%s is not a valid image' % imagePath)
    continue
  #数据库中都是二进制数据
  imageKey = 'image-%09d' % cnt#9位数不足填零
  labelKey = 'label-%09d' % cnt
  cache[imageKey] = imageBin
  cache[labelKey] = label.encode()
  if lexiconList:
   lexiconKey = 'lexicon-%09d' % cnt
   cache[lexiconKey] = ' '.join(lexiconList[i])
  if cnt % 1000 == 0:
   writeCache(env, cache)
   cache = {}
   print('Written %d / %d' % (cnt, nSamples))
  cnt += 1
 nSamples = cnt-1
 cache['num-samples'] = str(nSamples).encode()
 writeCache(env, cache)
 print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
  with open(txt_path,'r') as fr:
    jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
  txt_content_list=[]
  for jpg in jpg_list:
    label_path=jpg.replace('.jpg','.txt')
    with open(label_path,'r') as fr:
      try:
        str_tmp=fr.readline()
      except UnicodeDecodeError as e:
        print(label_path)
        raise(e)
      txt_content_list.append(str_tmp.strip())
  return jpg_list,txt_content_list
if __name__ == "__main__":
 txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
 lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
 imagePathList,labelList=get_sample_list(txt_path)
 createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__

from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
 import moxing as mox
 #moxing是一个分布式的框架 跳过
class LmdbDataset(data.Dataset):
 def __init__(self, root, voc_type, max_len, num_samples, transform=None):
  super(LmdbDataset, self).__init__()
  if global_args.run_on_remote:
   dataset_name = os.path.basename(root)
   data_cache_url = "/cache/%s" % dataset_name
   if not os.path.exists(data_cache_url):
    os.makedirs(data_cache_url)
   if mox.file.exists(root):
    mox.file.copy_parallel(root, data_cache_url)
   else:
    raise ValueError("%s not exists!" % root)
   self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
  else:
   self.env = lmdb.open(root, max_readers=32, readonly=True)
  assert self.env is not None, "cannot create lmdb from %s" % root
  self.txn = self.env.begin()
  self.voc_type = voc_type
  self.transform = transform
  self.max_len = max_len
  self.nSamples = int(self.txn.get(b"num-samples"))
  self.nSamples = min(self.nSamples, num_samples)
  assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
  self.EOS = 'EOS'
  self.PADDING = 'PADDING'
  self.UNKNOWN = 'UNKNOWN'
  self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
  self.char2id = dict(zip(self.voc, range(len(self.voc))))
  self.id2char = dict(zip(range(len(self.voc)), self.voc))
  self.rec_num_classes = len(self.voc)
  self.lowercase = (voc_type == 'LOWERCASE')
 def __len__(self):
  return self.nSamples
 def __getitem__(self, index):
  assert index <= len(self), 'index range error'
  index += 1
  img_key = b'image-%09d' % index
  imgbuf = self.txn.get(img_key)
  #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象
  buf = six.BytesIO()
  buf.write(imgbuf)
  buf.seek(0)
  try:
   img = Image.open(buf).convert('RGB')
   # img = Image.open(buf).convert('L')
   # img = img.convert('RGB')
  except IOError:
   print('Corrupted image for %d' % index)
   return self[index + 1]
  # reconition labels
  label_key = b'label-%09d' % index
  word = self.txn.get(label_key).decode()
  if self.lowercase:
   word = word.lower()
  ## fill with the padding token
  label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int)
  label_list = []
  for char in word:
   if char in self.char2id:
    label_list.append(self.char2id[char])
   else:
    ## add the unknown token
    print('{0} is out of vocabulary.'.format(char))
    label_list.append(self.char2id[self.UNKNOWN])
  ## add a stop token
  label_list = label_list + [self.char2id[self.EOS]]
  assert len(label_list) <= self.max_len
  label[:len(label_list)] = np.array(label_list)
  if len(label) <= 0:
   return self[index + 1]
  # label length
  label_len = len(label_list)
  if self.transform is not None:
   img = self.transform(img)
  return img, label, label_len

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python操作列表的常用方法分享
Feb 13 Python
基于Python实现通过微信搜索功能查看谁把你删除了
Jan 27 Python
详解Python 协程的详细用法使用和例子
Jun 15 Python
解决python3.5 正常安装 却不能直接使用Tkinter包的问题
Feb 22 Python
Python里字典的基本用法(包括嵌套字典)
Feb 27 Python
Python数据类型之Dict字典实例详解
May 07 Python
python如何解析配置文件并应用到项目中
Jun 27 Python
Python使用scipy模块实现一维卷积运算示例
Sep 05 Python
Python3实现飞机大战游戏
Apr 24 Python
Python Socket多线程并发原理及实现
Dec 11 Python
python中remove函数的踩坑记录
Jan 04 Python
python使用pygame创建精灵Sprite
Apr 06 Python
Python Gluon参数和模块命名操作教程
Dec 18 #Python
python turtle 绘制太极图的实例
Dec 18 #Python
Python使用gluon/mxnet模块实现的mnist手写数字识别功能完整示例
Dec 18 #Python
简单了解Python读取大文件代码实例
Dec 18 #Python
python 比较2张图片的相似度的方法示例
Dec 18 #Python
使用Python的Turtle库绘制森林的实例
Dec 18 #Python
python3 requests库实现多图片爬取教程
Dec 18 #Python
You might like
PHP学习笔记之二
2011/01/17 PHP
PHP仿盗链代码
2012/06/03 PHP
php处理静态页面:页面设置缓存时间实例
2017/06/22 PHP
thinkPHP框架实现类似java过滤器的简单方法示例
2018/09/05 PHP
jquery 双色表格实现代码
2009/12/08 Javascript
jquery.validate使用攻略 第一部
2010/07/01 Javascript
jQuery)扩展jQuery系列之一 模拟alert,confirm(一)
2010/12/04 Javascript
从盛大通行证上摘下来的身份证验证js代码
2011/01/11 Javascript
基于JQuery的类似新浪微博展示信息效果的代码
2012/07/23 Javascript
设置checkbox为只读(readOnly)的两种方式
2013/10/11 Javascript
setInterval()和setTimeout()的用法和区别示例介绍
2013/11/17 Javascript
Jquery方式获取iframe页面中的 Dom元素
2014/05/07 Javascript
Egret引擎开发指南之运行项目
2014/09/03 Javascript
解决ueditor jquery javascript 取值问题
2014/12/30 Javascript
JavaScript中用于生成随机数的Math.random()方法
2015/06/15 Javascript
jquery siblings获取同辈元素用法实例分析
2016/07/25 Javascript
JavaScript实现无穷滚动加载数据
2017/05/06 Javascript
使用requirejs模块化开发多页面一个入口js的使用方式
2017/06/14 Javascript
jQuery 控制文本框自动缩小字体填充
2017/06/16 jQuery
ES6新增的math,Number方法
2017/08/06 Javascript
JavaScript全屏和退出全屏事件总结(附代码)
2017/08/17 Javascript
javascript/jquery实现点击触发事件的方法分析
2019/11/11 jQuery
tracking.js实现前端人脸识别功能
2020/04/16 Javascript
vuex刷新后数据丢失的解决方法
2020/10/18 Javascript
nuxt 服务器渲染动态设置 title和seo关键字的操作
2020/11/05 Javascript
Python利用前序和中序遍历结果重建二叉树的方法
2016/04/27 Python
Python实现XML文件解析的示例代码
2018/02/05 Python
使用python Telnet远程登录执行程序的方法
2019/01/26 Python
python2爬取百度贴吧指定关键字和图片代码实例
2019/08/14 Python
IronPython连接MySQL的方法步骤
2019/12/27 Python
python批量修改xml属性的实现方式
2020/03/05 Python
python调用百度API实现人脸识别
2020/11/17 Python
CSS3之transition实现下划线的示例代码
2018/05/30 HTML / CSS
程序运行正确, 但退出时却"core dump"了,怎么回事
2014/02/19 面试题
任命书怎么写
2015/03/02 职场文书
Java详细解析==和equals的区别
2022/04/07 Java/Android