Python Pytorch查询图像的特征从集合或数据库中查找图像

随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加。亚马逊、阿里巴巴、Myntra等公司一直在大量利用图像检索技术。当然,只有当通常的信息检索技术失败时,图像检索才会开始工作。

Posted in Python onApril 09, 2022

随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加。

亚马逊、阿里巴巴、Myntra等公司一直在大量利用图像检索技术。当然,只有当通常的信息检索技术失败时,图像检索才会开始工作。

背景

图像检索的基本本质是根据查询图像的特征从集合或数据库中查找图像。

大多数情况下,这种特征是图像之间简单的视觉相似性。在一个复杂的问题中,这种特征可能是两幅图像在风格上的相似性,甚至是互补性。

由于原始形式的图像不会在基于像素的数据中反映这些特征,因此我们需要将这些像素数据转换为一个潜空间,在该空间中,图像的表示将反映这些特征。

一般来说,在潜空间中,任何两个相似的图像都会相互靠近,而不同的图像则会相隔很远。这是我们用来训练我们的模型的基本管理规则。一旦我们这样做,检索部分只需搜索潜在空间,在给定查询图像表示的潜在空间中拾取最近的图像。大多数情况下,它是在最近邻搜索的帮助下完成的。

因此,我们可以将我们的方法分为两部分:

  • 图像表现
  • 搜索

我们将在Oxford 102 Flowers数据集上解决这两个部分。

图像表现

我们将使用一种叫做暹罗模型的东西,它本身并不是一种全新的模型,而是一种训练模型的技术。大多数情况下,这是与triplet loss一起使用的。这个技术的基本组成部分是三元组。

三元组是3个独立的数据样本,比如A(锚点),B(阳性)和C(阴性);其中A和B相似或具有相似的特征(可能是同一类),而C与A和B都不相似。这三个样本共同构成了训练数据的一个单元——三元组。

注:任何图像检索任务的90%都体现在暹罗网络、triplet loss和三元组的创建中。如果你成功地完成了这些,那么整个努力的成功或多或少是有保证的。

首先,我们将创建管道的这个组件——数据。下面我们将在PyTorch中创建一个自定义数据集和数据加载器,它将从数据集中生成三元组。

class TripletData(Dataset):
    def __init__(self, path, transforms, split="train"):
 
        self.path = path
        self.split = split    # train or valid
        self.cats = 102       # number of categories
        self.transforms = transforms
 
        
    def __getitem__(self, idx):
 
        # our positive class for the triplet
        idx = str(idx%self.cats + 1)
 
        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, idx))
        im1, im2 = random.sample(positives, 2)
 
        # choosing a negative class and negative image (im3)
        negative_cats = [str(x+1) for x in range(self.cats)]
        negative_cats.remove(idx)
        negative_cat = str(random.choice(negative_cats))
        negatives = os.listdir(os.path.join(self.path, negative_cat))
 
        im3 = random.choice(negatives)
 
        im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)
 
        im1 = self.transforms(Image.open(im1))
 
        im2 = self.transforms(Image.open(im2))
 
        im3 = self.transforms(Image.open(im3))
 
        return [im1, im2, im3]
 
    
    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images/ number of categories is a good choice
    def __len__(self):
        return self.cats*8
# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)
train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)

现在我们有了数据,让我们转到暹罗网络。

暹罗网络给人的印象是2个或3个模型,但是它本身是一个单一的模型。所有这些模型共享权重,即只有一个模型。

Python Pytorch查询图像的特征从集合或数据库中查找图像

如前所述,将整个体系结构结合在一起的关键因素是triplet loss。triplet loss产生了一个目标函数,该函数迫使相似输入对(锚点和正)之间的距离小于不同输入对(锚点和负)之间的距离,并限定一定的阈值。

下面我们来看看triplet loss以及训练管道实现。

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
# Training
for epoch in range(epochs):
    
    model.train()
    epoch_loss = 0.0
    
    for data in tqdm(train_loader):
        
        optimizer.zero_grad()
        x1,x2,x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))
 
    
    
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
 
# Training
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
 
        optimizer.zero_grad()
        
        x1,x2,x3 = data
        
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))

到目前为止,我们的模型已经经过训练,可以将图像转换为一个嵌入空间。接下来,我们进入搜索部分。

搜索

我们可以很容易地使用Scikit Learn提供的最近邻搜索。我们将探索新的更好的东西,而不是走简单的路线。

我们将使用Faiss。这比最近的邻居要快得多,如果我们有大量的图像,这种速度上的差异会变得更加明显。

下面我们将演示如何在给定查询图像时,在存储的图像表示中搜索最近的图像。

#!pip install faiss-gpu
import faiss                            
faiss_index = faiss.IndexFlatL2(1000)   # build the index
 
# storing the image representations
im_indices = []
 
with torch.no_grad():
    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
        
        im = Image.open(f)
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        preds = model(im)
        preds = np.array([preds[0].cpu().numpy()])
        faiss_index.add(preds) #add the representation to index
        im_indices.append(f)   #store the image name to find it later on
 
        
# Retrieval with a query image
with torch.no_grad():
    for f in os.listdir(PATH_TEST):
        
        # query/test image
        im = Image.open(os.path.join(PATH_TEST,f))
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        test_embed = model(im).cpu().numpy()
        
        _, I = faiss_index.search(test_embed, 5)
        print("Retrieved Image: {}".format(im_indices[I[0][0]]))

这涵盖了基于现代深度学习的图像检索,但不会使其变得太复杂。大多数检索问题都可以通过这个基本管道解决。

以上就是Python Pytorch学习之图像检索实践的详细内容,更多关于Python Pytorch图像检索的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python完全新手教程
Feb 08 Python
零基础写python爬虫之打包生成exe文件
Nov 06 Python
python统计日志ip访问数的方法
Jul 06 Python
深入解析Python中的集合类型操作符
Aug 19 Python
python通过socket实现多个连接并实现ssh功能详解
Nov 08 Python
python 自定义对象的打印方法
Jan 12 Python
简单了解python代码优化小技巧
Jul 08 Python
django框架forms组件用法实例详解
Dec 10 Python
pycharm部署、配置anaconda环境的教程
Mar 24 Python
Python创建自己的加密货币的示例
Mar 01 Python
详解python的内存分配机制
May 10 Python
Python 数据科学 Matplotlib图库详解
Jul 07 Python
Python实现科学占卜 让视频自动打码
Python自动化工具之实现Excel转Markdown表格
Python加密技术之RSA加密解密的实现
Apr 08 #Python
Python识别花卉种类鉴定网络热门植物并自动整理分类
请求模块urllib之PYTHON爬虫的基本使用
用Python仅20行代码编写一个简单的端口扫描器
Python实现视频自动打码的示例代码
Apr 08 #Python
You might like
PHP 事务处理数据实现代码
2010/05/13 PHP
php安全配置 如何配置使其更安全
2011/12/16 PHP
php匹配字符中链接地址的方法
2014/12/22 PHP
php支持中文字符串分割的函数
2015/05/28 PHP
Linux系统递归生成目录中文件的md5的方法
2015/06/29 PHP
浅析PHP7新功能及语法变化总结
2016/06/17 PHP
JavaScript中的对象化编程
2008/01/16 Javascript
使用jquery给input和textarea设定ie中的focus
2008/05/29 Javascript
css值转换成数值请抛弃parseInt
2011/10/24 Javascript
js导航菜单(自写)简单大方
2013/03/28 Javascript
快速查找数组中的某个元素并返回下标示例
2013/09/03 Javascript
JS实现让网页背景图片斜向移动的方法
2015/02/25 Javascript
详谈innerHTML innerText的使用和区别
2017/08/18 Javascript
讲解vue-router之什么是编程式路由
2018/05/28 Javascript
jQuery+Datatables实现表格批量删除功能【推荐】
2018/10/24 jQuery
深入浅析Vue.js 中的 v-for 列表渲染指令
2018/11/19 Javascript
vue动态添加路由addRoutes之不能将动态路由存入缓存的解决
2019/02/19 Javascript
js防抖函数和节流函数使用场景和实现区别示例分析
2020/04/11 Javascript
微信小程序tab左右滑动切换功能的实现代码
2021/02/08 Javascript
[00:16]热血竞技场
2019/03/06 DOTA
python实现计算资源图标crc值的方法
2014/10/05 Python
Python利用matplotlib生成图片背景及图例透明的效果
2017/04/27 Python
python负载均衡的简单实现方法
2018/02/04 Python
查看python下OpenCV版本的方法
2018/08/03 Python
CSS3媒体查询Media Queries基础学习教程
2016/02/29 HTML / CSS
逼真的HTML5树叶飘落动画
2016/03/01 HTML / CSS
美国婴儿和儿童家具网上商店:ABaby.com
2018/07/02 全球购物
荷兰街头时尚之家:Funkie House
2019/03/18 全球购物
生产副总岗位职责
2013/11/28 职场文书
留学自荐信写作方法
2014/01/27 职场文书
小学生个人先进事迹材料
2014/05/08 职场文书
乡党政领导班子群众路线教育实践活动个人对照检查材料
2014/09/20 职场文书
病人写给医生的感谢信
2015/01/23 职场文书
2015年转正工作总结范文
2015/04/02 职场文书
如何正确理解python装饰器
2021/06/15 Python
Java中try catch处理异常示例
2021/12/06 Java/Android