keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之循环对象
Aug 30 Python
Python中实现对list做减法操作介绍
Jan 09 Python
Python读写ini文件的方法
May 28 Python
python文件的md5加密方法
Apr 06 Python
Python的socket模块源码中的一些实现要点分析
Jun 06 Python
mysql 之通过配置文件链接数据库
Aug 12 Python
python实现对任意大小图片均匀切割的示例
Dec 05 Python
Python 判断奇数偶数的方法
Dec 20 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 Python
浅谈matplotlib.pyplot与axes的关系
Mar 06 Python
Python函数生成器原理及使用详解
Mar 12 Python
150行python代码实现贪吃蛇游戏
Apr 24 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
php利用curl抓取新浪微博内容示例
2014/04/27 PHP
PHP实现AES256加密算法实例
2014/09/22 PHP
PHP全局使用Laravel辅助函数dd
2019/12/26 PHP
AJAX的跨域与JSONP(为文章自动添加短址的功能)
2010/01/17 Javascript
使用ImageMagick进行图片缩放、合成与裁剪(js+python)
2013/09/16 Javascript
js如何判断用户是在PC端和还是移动端访问
2014/04/24 Javascript
Javascript基础教程之for循环
2015/01/18 Javascript
JSON与XML优缺点对比分析
2015/07/17 Javascript
JavaScript SHA512&amp;SHA256加密算法详解
2015/08/11 Javascript
分享有关jQuery中animate、slide、fade等动画的连续触发、滞后反复执行的bug
2016/01/10 Javascript
jQuery实现的自适应焦点图效果完整实例
2016/08/24 Javascript
Bootstrap框架结合jQuery仿百度换肤功能实例解析
2016/09/17 Javascript
easyui导出excel无法弹出下载框的快速解决方法
2016/11/10 Javascript
值得分享的Bootstrap Table使用教程
2016/11/23 Javascript
easyui datebox 时间限制,datebox开始时间限制结束时间,datebox截止日期比起始日期大的实现代码
2017/01/12 Javascript
js实现登录框鼠标拖拽效果
2017/03/09 Javascript
JS中正则表达式全局匹配模式 /g用法详解
2017/04/01 Javascript
vue.js单页面应用实例的简单实现
2017/04/10 Javascript
vue.js声明式渲染和条件与循环基础知识
2017/07/31 Javascript
js实现单元格拖拽效果
2020/02/10 Javascript
原生js实现日期选择插件
2020/05/21 Javascript
python列表与元组详解实例
2013/11/01 Python
用Python制作检测Linux运行信息的工具的教程
2015/04/01 Python
Python实现高效求解素数代码实例
2015/06/30 Python
python3.5 + PyQt5 +Eric6 实现的一个计算器代码
2017/03/11 Python
Python实现求解括号匹配问题的方法
2018/04/17 Python
Python中函数参数调用方式分析
2018/08/09 Python
Python简单基础小程序的实例代码
2019/04/28 Python
python 生成器需注意的小问题
2020/09/29 Python
使用Python判断一个文件是否被占用的方法教程
2020/12/16 Python
个人能力自我鉴赏
2014/01/25 职场文书
2014公安机关纪律作风整顿思想汇报
2014/09/13 职场文书
2015年终个人政治思想工作总结
2015/11/24 职场文书
销区经理年终述职报告模板
2019/11/28 职场文书
vue实现列表拖拽排序的示例代码
2022/04/08 Vue.js
Spring boot admin 服务监控利器详解
2022/08/05 Java/Android