Python Pytorch查询图像的特征从集合或数据库中查找图像

随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加。亚马逊、阿里巴巴、Myntra等公司一直在大量利用图像检索技术。当然,只有当通常的信息检索技术失败时,图像检索才会开始工作。

Posted in Python onApril 09, 2022

随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加。

亚马逊、阿里巴巴、Myntra等公司一直在大量利用图像检索技术。当然,只有当通常的信息检索技术失败时,图像检索才会开始工作。

背景

图像检索的基本本质是根据查询图像的特征从集合或数据库中查找图像。

大多数情况下,这种特征是图像之间简单的视觉相似性。在一个复杂的问题中,这种特征可能是两幅图像在风格上的相似性,甚至是互补性。

由于原始形式的图像不会在基于像素的数据中反映这些特征,因此我们需要将这些像素数据转换为一个潜空间,在该空间中,图像的表示将反映这些特征。

一般来说,在潜空间中,任何两个相似的图像都会相互靠近,而不同的图像则会相隔很远。这是我们用来训练我们的模型的基本管理规则。一旦我们这样做,检索部分只需搜索潜在空间,在给定查询图像表示的潜在空间中拾取最近的图像。大多数情况下,它是在最近邻搜索的帮助下完成的。

因此,我们可以将我们的方法分为两部分:

  • 图像表现
  • 搜索

我们将在Oxford 102 Flowers数据集上解决这两个部分。

图像表现

我们将使用一种叫做暹罗模型的东西,它本身并不是一种全新的模型,而是一种训练模型的技术。大多数情况下,这是与triplet loss一起使用的。这个技术的基本组成部分是三元组。

三元组是3个独立的数据样本,比如A(锚点),B(阳性)和C(阴性);其中A和B相似或具有相似的特征(可能是同一类),而C与A和B都不相似。这三个样本共同构成了训练数据的一个单元——三元组。

注:任何图像检索任务的90%都体现在暹罗网络、triplet loss和三元组的创建中。如果你成功地完成了这些,那么整个努力的成功或多或少是有保证的。

首先,我们将创建管道的这个组件——数据。下面我们将在PyTorch中创建一个自定义数据集和数据加载器,它将从数据集中生成三元组。

class TripletData(Dataset):
    def __init__(self, path, transforms, split="train"):
 
        self.path = path
        self.split = split    # train or valid
        self.cats = 102       # number of categories
        self.transforms = transforms
 
        
    def __getitem__(self, idx):
 
        # our positive class for the triplet
        idx = str(idx%self.cats + 1)
 
        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, idx))
        im1, im2 = random.sample(positives, 2)
 
        # choosing a negative class and negative image (im3)
        negative_cats = [str(x+1) for x in range(self.cats)]
        negative_cats.remove(idx)
        negative_cat = str(random.choice(negative_cats))
        negatives = os.listdir(os.path.join(self.path, negative_cat))
 
        im3 = random.choice(negatives)
 
        im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)
 
        im1 = self.transforms(Image.open(im1))
 
        im2 = self.transforms(Image.open(im2))
 
        im3 = self.transforms(Image.open(im3))
 
        return [im1, im2, im3]
 
    
    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images/ number of categories is a good choice
    def __len__(self):
        return self.cats*8
# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)
train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)

现在我们有了数据,让我们转到暹罗网络。

暹罗网络给人的印象是2个或3个模型,但是它本身是一个单一的模型。所有这些模型共享权重,即只有一个模型。

Python Pytorch查询图像的特征从集合或数据库中查找图像

如前所述,将整个体系结构结合在一起的关键因素是triplet loss。triplet loss产生了一个目标函数,该函数迫使相似输入对(锚点和正)之间的距离小于不同输入对(锚点和负)之间的距离,并限定一定的阈值。

下面我们来看看triplet loss以及训练管道实现。

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
# Training
for epoch in range(epochs):
    
    model.train()
    epoch_loss = 0.0
    
    for data in tqdm(train_loader):
        
        optimizer.zero_grad()
        x1,x2,x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))
 
    
    
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
 
# Training
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
 
        optimizer.zero_grad()
        
        x1,x2,x3 = data
        
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))

到目前为止,我们的模型已经经过训练,可以将图像转换为一个嵌入空间。接下来,我们进入搜索部分。

搜索

我们可以很容易地使用Scikit Learn提供的最近邻搜索。我们将探索新的更好的东西,而不是走简单的路线。

我们将使用Faiss。这比最近的邻居要快得多,如果我们有大量的图像,这种速度上的差异会变得更加明显。

下面我们将演示如何在给定查询图像时,在存储的图像表示中搜索最近的图像。

#!pip install faiss-gpu
import faiss                            
faiss_index = faiss.IndexFlatL2(1000)   # build the index
 
# storing the image representations
im_indices = []
 
with torch.no_grad():
    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
        
        im = Image.open(f)
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        preds = model(im)
        preds = np.array([preds[0].cpu().numpy()])
        faiss_index.add(preds) #add the representation to index
        im_indices.append(f)   #store the image name to find it later on
 
        
# Retrieval with a query image
with torch.no_grad():
    for f in os.listdir(PATH_TEST):
        
        # query/test image
        im = Image.open(os.path.join(PATH_TEST,f))
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        test_embed = model(im).cpu().numpy()
        
        _, I = faiss_index.search(test_embed, 5)
        print("Retrieved Image: {}".format(im_indices[I[0][0]]))

这涵盖了基于现代深度学习的图像检索,但不会使其变得太复杂。大多数检索问题都可以通过这个基本管道解决。

以上就是Python Pytorch学习之图像检索实践的详细内容,更多关于Python Pytorch图像检索的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
详细解读Python中解析XML数据的方法
Oct 15 Python
python将.ppm格式图片转换成.jpg格式文件的方法
Oct 27 Python
BP神经网络原理及Python实现代码
Dec 18 Python
PyQt5实现类似别踩白块游戏
Jan 24 Python
python字典改变value值方法总结
Jun 21 Python
解决Python中pandas读取*.csv文件出现编码问题
Jul 12 Python
使用 Django Highcharts 实现数据可视化过程解析
Jul 31 Python
Python Django 页面上展示固定的页码数实现代码
Aug 21 Python
Python 操作mysql数据库查询之fetchone(), fetchmany(), fetchall()用法示例
Oct 17 Python
Python csv文件的读写操作实例详解
Nov 19 Python
python下对hsv颜色空间进行量化操作
Jun 04 Python
pytorch 带batch的tensor类型图像显示操作
May 20 Python
Python实现科学占卜 让视频自动打码
Python自动化工具之实现Excel转Markdown表格
Python加密技术之RSA加密解密的实现
Apr 08 #Python
Python识别花卉种类鉴定网络热门植物并自动整理分类
请求模块urllib之PYTHON爬虫的基本使用
用Python仅20行代码编写一个简单的端口扫描器
Python实现视频自动打码的示例代码
Apr 08 #Python
You might like
火车头discuz6.1 完美采集的php接口文件
2009/09/13 PHP
深入理解PHP类的自动载入机制
2016/09/16 PHP
PHP 中使用explode()函数切割字符串为数组的示例
2017/05/06 PHP
PHP XML Expat解析器知识点总结
2019/02/15 PHP
JQuery UI DatePicker中z-index默认为1的解决办法
2010/09/28 Javascript
基于jquery的分页控件(C#)
2011/01/06 Javascript
基于PHP+Jquery制作的可编辑的表格的代码
2011/04/10 Javascript
zTree插件之多选下拉菜单实例代码
2013/11/06 Javascript
js判读浏览器是否支持html5的canvas的代码
2013/11/18 Javascript
jquery获取checkbox的值并post提交
2015/01/14 Javascript
javascript中checkbox使用方法实例演示
2015/11/19 Javascript
浅谈json取值(对象和数组)
2016/06/24 Javascript
原生js实现吸顶效果
2017/03/13 Javascript
Angular如何在应用初始化时运行代码详解
2018/06/11 Javascript
使用 js 简单的实现 bind、call 、aplly代码实例
2019/09/07 Javascript
Vue2.0 $set()的正确使用详解
2020/07/28 Javascript
[54:19]完美世界DOTA2联赛PWL S2 Magma vs PXG 第二场 11.28
2020/12/01 DOTA
python实现哈希表
2014/02/07 Python
python测试驱动开发实例
2014/10/08 Python
Python中使用md5sum检查目录中相同文件代码分享
2015/02/02 Python
python统计文本文件内单词数量的方法
2015/05/30 Python
Python编程中的文件读写及相关的文件对象方法讲解
2016/01/19 Python
Python3中的真除和Floor除法用法分析
2016/03/16 Python
python实现比较文件内容异同
2018/06/22 Python
pytorch: tensor类型的构建与相互转换实例
2018/07/26 Python
python中字符串数组逆序排列方法总结
2019/06/23 Python
python中wx模块的具体使用方法
2020/05/15 Python
Keras框架中的epoch、bacth、batch size、iteration使用介绍
2020/06/10 Python
Python将字典转换为XML的方法
2020/08/01 Python
中式结婚主持词
2014/03/14 职场文书
2014年园林绿化工作总结
2014/12/11 职场文书
2014幼儿园教育教学工作总结
2014/12/17 职场文书
2015年教研室工作总结范文
2015/05/23 职场文书
tp5使用layui实现多个图片上传(带附件选择)的方法实例
2021/11/17 PHP
什么是SOLID
2022/03/24 Javascript
js前端设计模式优化50%表单校验代码示例
2022/06/21 Javascript