keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置数据类型详解
Aug 18 Python
对Python3中的input函数详解
Apr 22 Python
对pandas数据判断是否为NaN值的方法详解
Nov 06 Python
Python 3.3实现计算两个日期间隔秒数/天数的方法示例
Jan 07 Python
python提取log文件内容并画出图表
Jul 08 Python
在django view中给form传入参数的例子
Jul 19 Python
Python+Pyqt实现简单GUI电子时钟
Feb 22 Python
Python 使用元类type创建类对象常见应用详解
Oct 17 Python
python装饰器使用实例详解
Dec 14 Python
pytorch查看torch.Tensor和model是否在CUDA上的实例
Jan 03 Python
python中return不返回值的问题解析
Jul 22 Python
分享python函数常见关键字
Apr 26 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
PHP封装的简单连接MongoDB类示例
2019/02/13 PHP
经典的解除许多网站无法复制文字的绝招
2006/12/31 Javascript
jQuery实现复选框全选/取消全选/反选及获得选择的值
2014/06/12 Javascript
javascript将url中的参数加密解密代码
2014/11/17 Javascript
jQuery实现“扫码阅读”功能
2015/01/21 Javascript
JavaScript在浏览器标题栏上显示当前日期和时间的方法
2015/03/19 Javascript
javascript实现网页子页面遍历回调的方法(涉及 window.frames、递归函数、函数上下文)
2015/07/27 Javascript
js带缩略图的图片轮播效果代码分享
2015/09/14 Javascript
JavaScript程序开发之JS代码放置的位置
2016/01/15 Javascript
js中获取 table节点各tr及td的内容简单实例
2016/10/14 Javascript
Angular2+国际化方案(ngx-translate)的示例代码
2017/08/23 Javascript
十分钟带你快速了解React16新特性
2017/11/10 Javascript
Vue中使用canvas方法总结
2019/02/12 Javascript
layui使用label标签的方法
2019/09/14 Javascript
Python实现的HTTP并发测试完整示例
2020/04/23 Python
python超简单解决约瑟夫环问题
2015/05/12 Python
python复制文件的方法实例详解
2015/05/22 Python
Python pass详细介绍及实例代码
2016/11/24 Python
解决Pycharm界面的子窗口不见了的问题
2019/01/17 Python
树莓派+摄像头实现对移动物体的检测
2019/06/22 Python
python 变量初始化空列表的例子
2019/11/28 Python
使用 Python ssh 远程登陆服务器的最佳方案
2020/03/06 Python
pycharm 关掉syntax检查操作
2020/06/09 Python
使用OpenCV对车道进行实时检测的实现示例代码
2020/06/19 Python
详解Sticky Footer 绝对底部的两种套路
2017/11/03 HTML / CSS
html5调用摄像头功能的实现代码
2018/05/07 HTML / CSS
Html5适配iphoneX刘海屏的简单实现
2019/04/09 HTML / CSS
Ralph Lauren法国官网:美国高品味时装品牌
2017/12/08 全球购物
俄罗斯购买内衣网站:Trusiki
2020/08/22 全球购物
PHP开发的一般流程
2013/08/13 面试题
知识就是力量演讲稿
2014/09/13 职场文书
感谢信范文大全
2015/01/23 职场文书
出国留学自荐信模板
2015/03/06 职场文书
起诉意见书范文
2015/05/19 职场文书
《曾国藩家书》读后感——读家书,立家风
2019/08/21 职场文书
新手必备Python开发环境搭建教程
2021/05/28 Python