Python Pytorch查询图像的特征从集合或数据库中查找图像

随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加。亚马逊、阿里巴巴、Myntra等公司一直在大量利用图像检索技术。当然,只有当通常的信息检索技术失败时,图像检索才会开始工作。

Posted in Python onApril 09, 2022

随着电子商务和在线网站的出现,图像检索在我们的日常生活中的应用一直在增加。

亚马逊、阿里巴巴、Myntra等公司一直在大量利用图像检索技术。当然,只有当通常的信息检索技术失败时,图像检索才会开始工作。

背景

图像检索的基本本质是根据查询图像的特征从集合或数据库中查找图像。

大多数情况下,这种特征是图像之间简单的视觉相似性。在一个复杂的问题中,这种特征可能是两幅图像在风格上的相似性,甚至是互补性。

由于原始形式的图像不会在基于像素的数据中反映这些特征,因此我们需要将这些像素数据转换为一个潜空间,在该空间中,图像的表示将反映这些特征。

一般来说,在潜空间中,任何两个相似的图像都会相互靠近,而不同的图像则会相隔很远。这是我们用来训练我们的模型的基本管理规则。一旦我们这样做,检索部分只需搜索潜在空间,在给定查询图像表示的潜在空间中拾取最近的图像。大多数情况下,它是在最近邻搜索的帮助下完成的。

因此,我们可以将我们的方法分为两部分:

  • 图像表现
  • 搜索

我们将在Oxford 102 Flowers数据集上解决这两个部分。

图像表现

我们将使用一种叫做暹罗模型的东西,它本身并不是一种全新的模型,而是一种训练模型的技术。大多数情况下,这是与triplet loss一起使用的。这个技术的基本组成部分是三元组。

三元组是3个独立的数据样本,比如A(锚点),B(阳性)和C(阴性);其中A和B相似或具有相似的特征(可能是同一类),而C与A和B都不相似。这三个样本共同构成了训练数据的一个单元——三元组。

注:任何图像检索任务的90%都体现在暹罗网络、triplet loss和三元组的创建中。如果你成功地完成了这些,那么整个努力的成功或多或少是有保证的。

首先,我们将创建管道的这个组件——数据。下面我们将在PyTorch中创建一个自定义数据集和数据加载器,它将从数据集中生成三元组。

class TripletData(Dataset):
    def __init__(self, path, transforms, split="train"):
 
        self.path = path
        self.split = split    # train or valid
        self.cats = 102       # number of categories
        self.transforms = transforms
 
        
    def __getitem__(self, idx):
 
        # our positive class for the triplet
        idx = str(idx%self.cats + 1)
 
        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, idx))
        im1, im2 = random.sample(positives, 2)
 
        # choosing a negative class and negative image (im3)
        negative_cats = [str(x+1) for x in range(self.cats)]
        negative_cats.remove(idx)
        negative_cat = str(random.choice(negative_cats))
        negatives = os.listdir(os.path.join(self.path, negative_cat))
 
        im3 = random.choice(negatives)
 
        im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)
 
        im1 = self.transforms(Image.open(im1))
 
        im2 = self.transforms(Image.open(im2))
 
        im3 = self.transforms(Image.open(im3))
 
        return [im1, im2, im3]
 
    
    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images/ number of categories is a good choice
    def __len__(self):
        return self.cats*8
# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)
train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)

现在我们有了数据,让我们转到暹罗网络。

暹罗网络给人的印象是2个或3个模型,但是它本身是一个单一的模型。所有这些模型共享权重,即只有一个模型。

Python Pytorch查询图像的特征从集合或数据库中查找图像

如前所述,将整个体系结构结合在一起的关键因素是triplet loss。triplet loss产生了一个目标函数,该函数迫使相似输入对(锚点和正)之间的距离小于不同输入对(锚点和负)之间的距离,并限定一定的阈值。

下面我们来看看triplet loss以及训练管道实现。

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
# Training
for epoch in range(epochs):
    
    model.train()
    epoch_loss = 0.0
    
    for data in tqdm(train_loader):
        
        optimizer.zero_grad()
        x1,x2,x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))
 
    
    
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
 
# Training
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
 
        optimizer.zero_grad()
        
        x1,x2,x3 = data
        
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))

到目前为止,我们的模型已经经过训练,可以将图像转换为一个嵌入空间。接下来,我们进入搜索部分。

搜索

我们可以很容易地使用Scikit Learn提供的最近邻搜索。我们将探索新的更好的东西,而不是走简单的路线。

我们将使用Faiss。这比最近的邻居要快得多,如果我们有大量的图像,这种速度上的差异会变得更加明显。

下面我们将演示如何在给定查询图像时,在存储的图像表示中搜索最近的图像。

#!pip install faiss-gpu
import faiss                            
faiss_index = faiss.IndexFlatL2(1000)   # build the index
 
# storing the image representations
im_indices = []
 
with torch.no_grad():
    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
        
        im = Image.open(f)
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        preds = model(im)
        preds = np.array([preds[0].cpu().numpy()])
        faiss_index.add(preds) #add the representation to index
        im_indices.append(f)   #store the image name to find it later on
 
        
# Retrieval with a query image
with torch.no_grad():
    for f in os.listdir(PATH_TEST):
        
        # query/test image
        im = Image.open(os.path.join(PATH_TEST,f))
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        test_embed = model(im).cpu().numpy()
        
        _, I = faiss_index.search(test_embed, 5)
        print("Retrieved Image: {}".format(im_indices[I[0][0]]))

这涵盖了基于现代深度学习的图像检索,但不会使其变得太复杂。大多数检索问题都可以通过这个基本管道解决。

以上就是Python Pytorch学习之图像检索实践的详细内容,更多关于Python Pytorch图像检索的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
jupyter安装小结
Mar 13 Python
Python-嵌套列表list的全面解析
Jun 08 Python
解决pyqt中ui编译成窗体.py中文乱码的问题
Dec 23 Python
itchat接口使用示例
Oct 23 Python
Python带动态参数功能的sqlite工具类
May 26 Python
Django2.1.3 中间件使用详解
Nov 26 Python
解决Djang2.0.1中的reverse导入失败的问题
Aug 16 Python
python爬虫之爬取百度音乐的实现方法
Aug 24 Python
Python使用Chrome插件实现爬虫过程图解
Jun 09 Python
详解Python的爬虫框架 Scrapy
Aug 03 Python
python字典与json转换的方法总结
Dec 28 Python
详解运行Python的神器Jupyter Notebook
Jun 03 Python
Python实现科学占卜 让视频自动打码
Python自动化工具之实现Excel转Markdown表格
Python加密技术之RSA加密解密的实现
Apr 08 #Python
Python识别花卉种类鉴定网络热门植物并自动整理分类
请求模块urllib之PYTHON爬虫的基本使用
用Python仅20行代码编写一个简单的端口扫描器
Python实现视频自动打码的示例代码
Apr 08 #Python
You might like
php数据结构 算法(PHP描述) 简单选择排序 simple selection sort
2011/08/09 PHP
2014最热门的24个php类库汇总
2014/12/18 PHP
PHP的伪随机数与真随机数详解
2015/05/27 PHP
PHP基于curl后台远程登录正方教务系统的方法
2016/10/14 PHP
详解PHP编码转换函数应用技巧
2016/10/22 PHP
PHP使用gearman进行异步的邮件或短信发送操作详解
2020/02/27 PHP
向大师们学习Javascript(视频与PPT)
2009/12/27 Javascript
javascript如何写热点图
2015/12/08 Javascript
Node.js获取前端ajax提交的request信息
2017/02/20 Javascript
vue安装和使用scss及sass与scss的区别详解
2018/10/15 Javascript
记录一次完整的react hooks实践
2019/03/11 Javascript
微信小程序BindTap快速连续点击目标页面跳转多次问题处理
2019/04/08 Javascript
JavaScript单线程和任务队列原理解析
2020/02/04 Javascript
antd-DatePicker组件获取时间值,及相关设置方式
2020/10/27 Javascript
jQuery实现简单轮播图效果
2020/12/27 jQuery
pycharm 使用心得(三)Hello world!
2014/06/05 Python
python网络编程学习笔记(七):HTML和XHTML解析(HTMLParser、BeautifulSoup)
2014/06/09 Python
Centos5.x下升级python到python2.7版本教程
2015/02/14 Python
对python的文件内注释 help注释方法
2018/05/23 Python
python json.loads兼容单引号数据的方法
2018/12/19 Python
Python实现针对json中某个关键字段进行排序操作示例
2018/12/25 Python
matplotlib.pyplot绘图显示控制方法
2019/01/15 Python
python使用thrift教程的方法示例
2019/03/21 Python
python并发编程多进程 模拟抢票实现过程
2019/08/20 Python
python数据分析:关键字提取方式
2020/02/24 Python
学点简单的Django之第一个Django程序的实现
2021/02/24 Python
HTML5时代CSS设置漂亮字体取代图片
2014/09/04 HTML / CSS
html5录音功能实战示例
2019/03/25 HTML / CSS
Under Armour澳大利亚官网:美国知名的高端功能性运动品牌
2018/02/22 全球购物
手工制作的意大利礼服鞋:Ace Marks
2018/12/15 全球购物
限量版运动鞋和街头服饰:TheDrop
2020/09/06 全球购物
洗车工岗位职责
2014/03/15 职场文书
党代会心得体会
2014/09/04 职场文书
个人廉洁自律总结
2015/03/06 职场文书
优质服务标语口号
2015/12/26 职场文书
Redis特殊数据类型HyperLogLog基数统计算法讲解
2022/06/01 Redis