keras-siamese用自己的数据集实现详解


Posted in Python onJune 10, 2020

Siamese网络不做过多介绍,思想并不难,输入两个图像,输出这两张图像的相似度,两个输入的网络结构是相同的,参数共享。

主要发现很多代码都是基于mnist数据集的,下面说一下怎么用自己的数据集实现siamese网络。

首先,先整理数据集,相同的类放到同一个文件夹下,如下图所示:

keras-siamese用自己的数据集实现详解

接下来,将pairs及对应的label写到csv中,代码如下:

import os
import random
import csv
#图片所在的路径
path = '/Users/mac/Desktop/wxd/flag/category/'
#files列表保存所有类别的路径
files=[]
same_pairs=[]
different_pairs=[]
for file in os.listdir(path):
 if file[0]=='.':
  continue
 file_path = os.path.join(path,file)
 files.append(file_path)
#该地址为csv要保存到的路径,a表示追加写入
with open('/Users/mac/Desktop/wxd/flag/data.csv','a') as f:
 #保存相同对
 writer = csv.writer(f)
 for file in files:
  imgs = os.listdir(file) 
  for i in range(0,len(imgs)-1):
   for j in range(i+1,len(imgs)):
    pairs = []
    name = file.split(sep='/')[-1]
    pairs.append(path+name+'/'+imgs[i])
    pairs.append(path+name+'/'+imgs[j])
    pairs.append(1)
    writer.writerow(pairs)
 #保存不同对
 for i in range(0,len(files)-1):
  for j in range(i+1,len(files)):
   filea = files[i]
   fileb = files[j]
   imga_li = os.listdir(filea)
   imgb_li = os.listdir(fileb)
   random.shuffle(imga_li)
   random.shuffle(imgb_li)
   a_li = imga_li[:]
   b_li = imgb_li[:]
   for p in range(len(a_li)):
    for q in range(len(b_li)):
     pairs = []
     name1 = filea.split(sep='/')[-1]
     name2 = fileb.split(sep='/')[-1]
     pairs.append(path+name1+'/'+a_li[p])
     pairs.append(path+name2+'/'+b_li[q])
     pairs.append(0)
     writer.writerow(pairs)

相当于csv每一行都包含一对结果,每一行有三列,第一列第一张图片路径,第二列第二张图片路径,第三列是不是相同的label,属于同一个类的label为1,不同类的为0,可参考下图:

keras-siamese用自己的数据集实现详解

然后,由于keras的fit函数需要将训练数据都塞入内存,而大部分训练数据都较大,因此才用fit_generator生成器的方法,便可以训练大数据,代码如下:

from __future__ import absolute_import
from __future__ import print_function
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \
 Activation, ZeroPadding2D
from keras.layers import add, Flatten
from keras.utils import plot_model
from keras.metrics import top_k_categorical_accuracy
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
import tensorflow as tf
import random
import os
import cv2
import csv
import numpy as np
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import img_to_array
 
"""
自定义的参数
"""
im_width = 224
im_height = 224
epochs = 100
batch_size = 64
iterations = 1000
csv_path = ''
model_result = ''
 
 
# 计算欧式距离
def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))
 
def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)
 
# 计算loss
def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 square_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
 
def compute_accuracy(y_true, y_pred):
 '''计算准确率
 '''
 pred = y_pred.ravel() < 0.5
 print('pred:', pred)
 return np.mean(pred == y_true)
 
def accuracy(y_true, y_pred):
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
 
def processImg(filename):
 """
 :param filename: 图像的路径
 :return: 返回的是归一化矩阵
 """
 img = cv2.imread(filename)
 img = cv2.resize(img, (im_width, im_height))
 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 img = img_to_array(img)
 img /= 255
 return img
 
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', name=None):
 if name is not None:
  bn_name = name + '_bn'
  conv_name = name + '_conv'
 else:
  bn_name = None
  conv_name = None
 
 x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x)
 x = BatchNormalization(axis=3, name=bn_name)(x)
 return x
 
def bottleneck_Block(inpt, nb_filters, strides=(1, 1), with_conv_shortcut=False):
 k1, k2, k3 = nb_filters
 x = Conv2d_BN(inpt, nb_filter=k1, kernel_size=1, strides=strides, padding='same')
 x = Conv2d_BN(x, nb_filter=k2, kernel_size=3, padding='same')
 x = Conv2d_BN(x, nb_filter=k3, kernel_size=1, padding='same')
 if with_conv_shortcut:
  shortcut = Conv2d_BN(inpt, nb_filter=k3, strides=strides, kernel_size=1)
  x = add([x, shortcut])
  return x
 else:
  x = add([x, inpt])
  return x
 
def resnet_50():
 width = im_width
 height = im_height
 channel = 3
 inpt = Input(shape=(width, height, channel))
 x = ZeroPadding2D((3, 3))(inpt)
 x = Conv2d_BN(x, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
 x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
 
 # conv2_x
 x = bottleneck_Block(x, nb_filters=[64, 64, 256], strides=(1, 1), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 
 # conv3_x
 x = bottleneck_Block(x, nb_filters=[128, 128, 512], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 
 # conv4_x
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 
 # conv5_x
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 
 x = AveragePooling2D(pool_size=(7, 7))(x)
 x = Flatten()(x)
 x = Dense(128, activation='relu')(x)
 return Model(inpt, x)
 
def generator(imgs, batch_size):
 """
 自定义迭代器
 :param imgs: 列表,每个包含一对矩阵以及label
 :param batch_size:
 :return:
 """
 while 1:
  random.shuffle(imgs)
  li = imgs[:batch_size]
  pairs = []
  labels = []
  for i in li:
   img1 = i[0]
   img2 = i[1]
   im1 = cv2.imread(img1)
   im2 = cv2.imread(img2)
   if im1 is None or im2 is None:
    continue
   label = int(i[2])
   img1 = processImg(img1)
   img2 = processImg(img2)
   pairs.append([img1, img2])
   labels.append(label)
  pairs = np.array(pairs)
  labels = np.array(labels)
  yield [pairs[:, 0], pairs[:, 1]], labels
 
input_shape = (im_width, im_height, 3)
base_network = resnet_50()
 
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
 
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
 
distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])
with tf.device("/gpu:0"):
 model = Model([input_a, input_b], distance)
 # train
 rms = RMSprop()
 rows = csv.reader(open(csv_path, 'r'), delimiter=',')
 imgs = list(rows)
 checkpoint = ModelCheckpoint(filepath=model_result+'flag_{epoch:03d}.h5', verbose=1)
 model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
 model.fit_generator(generator(imgs, batch_size), epochs=epochs, steps_per_epoch=iterations, callbacks=[checkpoint])

用了回调函数保存了每一个epoch后的模型,也可以保存最好的,之后需要对模型进行测试。

测试时直接用load_model会报错,而应该变成如下形式调用:

model = load_model(model_path,custom_objects={'contrastive_loss': contrastive_loss }) #选取自己的.h模型名称

emmm,到这里,就成功训练测试完了~~~写的比较粗,因为这个代码在官方给的mnist上的改动不大,只是方便大家用自己的数据集,大家如果有更好的方法可以提出意见~~~希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现从url中提取域名的几种方法
Sep 26 Python
Python的装饰器用法学习笔记
Jun 24 Python
将Dataframe数据转化为ndarry数据的方法
Jun 28 Python
详解python列表(list)的使用技巧及高级操作
Aug 15 Python
PyQt+socket实现远程操作服务器的方法示例
Aug 22 Python
Jupyter notebook无法导入第三方模块的解决方式
Apr 15 Python
python同时遍历两个list用法说明
May 02 Python
通过实例解析python创建进程常用方法
Jun 19 Python
如何解决安装python3.6.1失败
Jul 01 Python
python怎么调用自己的函数
Jul 01 Python
Python实现简单的2048小游戏
Mar 01 Python
Python torch.flatten()函数案例详解
Aug 30 Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
You might like
精致的人儿就要挑杯子喝咖啡
2021/03/03 冲泡冲煮
PHP采集类snoopy详细介绍(snoopy使用教程)
2014/06/19 PHP
Yii2实现跨mysql数据库关联查询排序功能代码
2017/02/10 PHP
PHP对象相关知识总结
2017/04/09 PHP
PHP实现基于回溯法求解迷宫问题的方法详解
2017/08/17 PHP
PHP实现页面静态化深入讲解
2021/03/04 PHP
JS 自动安装exe程序
2008/11/30 Javascript
分享精心挑选的23款美轮美奂的jQuery 图片特效插件
2012/08/14 Javascript
jQuery ui插件的使用方法代码实例
2013/05/08 Javascript
javascript中的this详解
2014/12/08 Javascript
浅谈JavaScript中变量和函数声明的提升
2016/08/09 Javascript
通过示例彻底搞懂js闭包
2017/08/10 Javascript
electron + vue项目实现打印小票功能及实现代码
2018/11/25 Javascript
JS温故而知新之变量提升和时间死区
2019/01/27 Javascript
浅谈vue 锚点指令v-anchor的使用
2019/11/13 Javascript
微信小程序静默登录的实现代码
2020/01/08 Javascript
收集的几个Python小技巧分享
2014/11/22 Python
CentOS 7下安装Python 3.5并与Python2.7兼容并存详解
2017/07/07 Python
在python中使用xlrd获取合并单元格的方法
2018/12/26 Python
使用python接入微信聊天机器人
2020/03/31 Python
如何基于Python批量下载音乐
2019/11/11 Python
关于python的缩进规则的知识点详解
2020/06/22 Python
Python如何急速下载第三方库详解
2020/11/02 Python
解决python3.6用cx_Oracle库连接Oracle的问题
2020/12/07 Python
CSS3制作苹果风格键盘特效
2015/02/26 HTML / CSS
CSS3 实现发光边框特效
2020/11/11 HTML / CSS
吉力贝官方网站:Jelly Belly
2019/03/11 全球购物
Linux管理员面试经常问道的相关命令
2014/12/12 面试题
企业优秀员工事迹材料
2014/05/28 职场文书
2014年党员创先争优承诺书
2014/05/29 职场文书
知识竞赛拉拉队口号
2014/06/16 职场文书
2014年党员评议表自我评价
2014/09/27 职场文书
2015年12.4全国法制宣传日活动总结
2015/03/24 职场文书
人口与计划生育责任书
2015/05/09 职场文书
幼儿园家长心得体会
2016/01/21 职场文书
Python访问Redis的详细操作
2021/06/26 Python