keras-siamese用自己的数据集实现详解


Posted in Python onJune 10, 2020

Siamese网络不做过多介绍,思想并不难,输入两个图像,输出这两张图像的相似度,两个输入的网络结构是相同的,参数共享。

主要发现很多代码都是基于mnist数据集的,下面说一下怎么用自己的数据集实现siamese网络。

首先,先整理数据集,相同的类放到同一个文件夹下,如下图所示:

keras-siamese用自己的数据集实现详解

接下来,将pairs及对应的label写到csv中,代码如下:

import os
import random
import csv
#图片所在的路径
path = '/Users/mac/Desktop/wxd/flag/category/'
#files列表保存所有类别的路径
files=[]
same_pairs=[]
different_pairs=[]
for file in os.listdir(path):
 if file[0]=='.':
  continue
 file_path = os.path.join(path,file)
 files.append(file_path)
#该地址为csv要保存到的路径,a表示追加写入
with open('/Users/mac/Desktop/wxd/flag/data.csv','a') as f:
 #保存相同对
 writer = csv.writer(f)
 for file in files:
  imgs = os.listdir(file) 
  for i in range(0,len(imgs)-1):
   for j in range(i+1,len(imgs)):
    pairs = []
    name = file.split(sep='/')[-1]
    pairs.append(path+name+'/'+imgs[i])
    pairs.append(path+name+'/'+imgs[j])
    pairs.append(1)
    writer.writerow(pairs)
 #保存不同对
 for i in range(0,len(files)-1):
  for j in range(i+1,len(files)):
   filea = files[i]
   fileb = files[j]
   imga_li = os.listdir(filea)
   imgb_li = os.listdir(fileb)
   random.shuffle(imga_li)
   random.shuffle(imgb_li)
   a_li = imga_li[:]
   b_li = imgb_li[:]
   for p in range(len(a_li)):
    for q in range(len(b_li)):
     pairs = []
     name1 = filea.split(sep='/')[-1]
     name2 = fileb.split(sep='/')[-1]
     pairs.append(path+name1+'/'+a_li[p])
     pairs.append(path+name2+'/'+b_li[q])
     pairs.append(0)
     writer.writerow(pairs)

相当于csv每一行都包含一对结果,每一行有三列,第一列第一张图片路径,第二列第二张图片路径,第三列是不是相同的label,属于同一个类的label为1,不同类的为0,可参考下图:

keras-siamese用自己的数据集实现详解

然后,由于keras的fit函数需要将训练数据都塞入内存,而大部分训练数据都较大,因此才用fit_generator生成器的方法,便可以训练大数据,代码如下:

from __future__ import absolute_import
from __future__ import print_function
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \
 Activation, ZeroPadding2D
from keras.layers import add, Flatten
from keras.utils import plot_model
from keras.metrics import top_k_categorical_accuracy
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
import tensorflow as tf
import random
import os
import cv2
import csv
import numpy as np
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import img_to_array
 
"""
自定义的参数
"""
im_width = 224
im_height = 224
epochs = 100
batch_size = 64
iterations = 1000
csv_path = ''
model_result = ''
 
 
# 计算欧式距离
def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))
 
def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)
 
# 计算loss
def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 square_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
 
def compute_accuracy(y_true, y_pred):
 '''计算准确率
 '''
 pred = y_pred.ravel() < 0.5
 print('pred:', pred)
 return np.mean(pred == y_true)
 
def accuracy(y_true, y_pred):
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
 
def processImg(filename):
 """
 :param filename: 图像的路径
 :return: 返回的是归一化矩阵
 """
 img = cv2.imread(filename)
 img = cv2.resize(img, (im_width, im_height))
 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 img = img_to_array(img)
 img /= 255
 return img
 
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', name=None):
 if name is not None:
  bn_name = name + '_bn'
  conv_name = name + '_conv'
 else:
  bn_name = None
  conv_name = None
 
 x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x)
 x = BatchNormalization(axis=3, name=bn_name)(x)
 return x
 
def bottleneck_Block(inpt, nb_filters, strides=(1, 1), with_conv_shortcut=False):
 k1, k2, k3 = nb_filters
 x = Conv2d_BN(inpt, nb_filter=k1, kernel_size=1, strides=strides, padding='same')
 x = Conv2d_BN(x, nb_filter=k2, kernel_size=3, padding='same')
 x = Conv2d_BN(x, nb_filter=k3, kernel_size=1, padding='same')
 if with_conv_shortcut:
  shortcut = Conv2d_BN(inpt, nb_filter=k3, strides=strides, kernel_size=1)
  x = add([x, shortcut])
  return x
 else:
  x = add([x, inpt])
  return x
 
def resnet_50():
 width = im_width
 height = im_height
 channel = 3
 inpt = Input(shape=(width, height, channel))
 x = ZeroPadding2D((3, 3))(inpt)
 x = Conv2d_BN(x, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
 x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
 
 # conv2_x
 x = bottleneck_Block(x, nb_filters=[64, 64, 256], strides=(1, 1), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 
 # conv3_x
 x = bottleneck_Block(x, nb_filters=[128, 128, 512], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 
 # conv4_x
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 
 # conv5_x
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 
 x = AveragePooling2D(pool_size=(7, 7))(x)
 x = Flatten()(x)
 x = Dense(128, activation='relu')(x)
 return Model(inpt, x)
 
def generator(imgs, batch_size):
 """
 自定义迭代器
 :param imgs: 列表,每个包含一对矩阵以及label
 :param batch_size:
 :return:
 """
 while 1:
  random.shuffle(imgs)
  li = imgs[:batch_size]
  pairs = []
  labels = []
  for i in li:
   img1 = i[0]
   img2 = i[1]
   im1 = cv2.imread(img1)
   im2 = cv2.imread(img2)
   if im1 is None or im2 is None:
    continue
   label = int(i[2])
   img1 = processImg(img1)
   img2 = processImg(img2)
   pairs.append([img1, img2])
   labels.append(label)
  pairs = np.array(pairs)
  labels = np.array(labels)
  yield [pairs[:, 0], pairs[:, 1]], labels
 
input_shape = (im_width, im_height, 3)
base_network = resnet_50()
 
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
 
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
 
distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])
with tf.device("/gpu:0"):
 model = Model([input_a, input_b], distance)
 # train
 rms = RMSprop()
 rows = csv.reader(open(csv_path, 'r'), delimiter=',')
 imgs = list(rows)
 checkpoint = ModelCheckpoint(filepath=model_result+'flag_{epoch:03d}.h5', verbose=1)
 model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
 model.fit_generator(generator(imgs, batch_size), epochs=epochs, steps_per_epoch=iterations, callbacks=[checkpoint])

用了回调函数保存了每一个epoch后的模型,也可以保存最好的,之后需要对模型进行测试。

测试时直接用load_model会报错,而应该变成如下形式调用:

model = load_model(model_path,custom_objects={'contrastive_loss': contrastive_loss }) #选取自己的.h模型名称

emmm,到这里,就成功训练测试完了~~~写的比较粗,因为这个代码在官方给的mnist上的改动不大,只是方便大家用自己的数据集,大家如果有更好的方法可以提出意见~~~希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python入门前的第一课 python怎样入门
Mar 06 Python
使用python编写udp协议的ping程序方法
Apr 22 Python
利用selenium爬虫抓取数据的基础教程
Jun 10 Python
pyqt 实现为长内容添加滑轮 scrollArea
Jun 19 Python
python+opencv实现摄像头调用的方法
Jun 22 Python
python字典排序的方法
Oct 12 Python
完美解决pyinstaller打包报错找不到依赖pypiwin32或pywin32-ctypes的错误
Apr 01 Python
Python新手学习装饰器
Jun 04 Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 Python
详解Python多线程下的list
Jul 03 Python
opencv 阈值分割的具体使用
Jul 08 Python
Pandas-DataFrame知识点汇总
Mar 16 Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
You might like
《OVERLORD》第四季,终于等到你!
2020/03/02 日漫
松下Panasonic RF-B65电路分析
2021/03/02 无线电
咖啡知识 咖啡养豆要养多久 排气又是什么
2021/03/06 新手入门
简单实现限定phpmyadmin访问ip的方法
2013/03/05 PHP
CentOS下与Apache连接的PHP多版本共存方案实现详解
2015/12/19 PHP
详解php几行代码实现CSV格式文件输出
2017/07/01 PHP
js 函数的执行环境和作用域链的深入解析
2009/11/01 Javascript
jquery之empty()与remove()区别说明
2010/09/10 Javascript
扩展easyui.datagrid,添加数据loading遮罩效果代码
2010/11/02 Javascript
浅析js中取绝对值的2种方法
2013/07/09 Javascript
JavaScript操作Oracle数据库示例
2015/03/06 Javascript
jQuery实现的经典滑动门效果
2015/09/22 Javascript
Vue移动端实现图片上传及超过1M压缩上传
2019/12/23 Javascript
在antd中setFieldsValue和defaultVal的用法
2020/10/29 Javascript
[01:16:13]DOTA2-DPC中国联赛 正赛 SAG vs Dragon BO3 第一场 2月22日
2021/03/11 DOTA
Python的lambda匿名函数的简单介绍
2013/04/25 Python
搭建Python的Django框架环境并建立和运行第一个App的教程
2016/07/02 Python
python中的字典操作及字典函数
2018/01/03 Python
Python实现的视频播放器功能完整示例
2018/02/01 Python
Python中反射和描述器总结
2018/09/23 Python
Python打包方法Pyinstaller的使用
2018/10/09 Python
解决使用PyCharm时无法启动控制台的问题
2019/01/19 Python
Python学习笔记之视频人脸检测识别实例教程
2019/03/06 Python
Python 网络编程之TCP客户端/服务端功能示例【基于socket套接字】
2019/10/12 Python
python如何通过twisted搭建socket服务
2020/02/03 Python
Python图片处理模块PIL操作方法(pillow)
2020/04/07 Python
python使用nibabel和sitk读取保存nii.gz文件实例
2020/07/01 Python
浅谈Python 命令行参数argparse写入图片路径操作
2020/07/12 Python
10行Python代码实现Web自动化管控的示例代码
2020/08/14 Python
python属于哪种语言
2020/08/16 Python
在使用非全零作为空指针内部表达的机器上, NULL是如何定义
2014/11/09 面试题
室内设计专业学生的自我评价分享
2013/11/27 职场文书
就业推荐表自我鉴定范文
2014/03/21 职场文书
思想作风纪律整顿心得体会
2014/09/04 职场文书
分析MySQL优化 index merge 后引起的死锁
2022/04/19 MySQL
python如何读取和存储dict()与.json格式文件
2022/06/25 Python