keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的检测网站挂马程序
Nov 30 Python
python实现telnet客户端的方法
Apr 15 Python
Python matplotlib画图实例之绘制拥有彩条的图表
Dec 28 Python
Python 编程速成(推荐)
Apr 15 Python
全面了解django的缓存机制及使用方法
Jul 22 Python
python中自带的三个装饰器的实现
Nov 08 Python
python判断两个序列的成员是否一样的实例代码
Mar 01 Python
windows10环境下用anaconda和VScode配置的图文教程
Mar 30 Python
python爬虫学习笔记之Beautifulsoup模块用法详解
Apr 09 Python
python批量修改文件名的示例
Sep 27 Python
python 爬虫如何正确的使用cookie
Oct 27 Python
Python Http请求json解析库用法解析
Nov 28 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
收集的二十一个实用便利的PHP函数代码
2010/04/22 PHP
解析PHP自带的进位制之间的转换函数
2013/06/08 PHP
PHP伪造来源HTTP_REFERER的方法实例详解
2015/07/06 PHP
php中通用的excel导出方法实例
2017/12/30 PHP
PHPExcel 修改已存在Excel的方法
2018/05/03 PHP
PHP count()函数讲解
2019/02/03 PHP
YII框架常用技巧总结
2019/04/27 PHP
Yii 框架入口脚本示例分析
2020/05/19 PHP
Opacity.js
2007/01/22 Javascript
JavaScript实现从数组中选出和等于固定值的n个数
2014/09/03 Javascript
基于javascript实现的搜索时自动提示功能
2014/12/26 Javascript
jQuery实现带滚动导航效果的全屏滚动相册实例
2015/06/19 Javascript
javascript获取本机操作系统类型的方法
2015/08/13 Javascript
JavaScript编写检测用户所使用的浏览器的代码示例
2016/05/05 Javascript
Javascript动画效果(1)
2016/10/11 Javascript
jQuery上传多张图片带进度条样式(DEMO)
2017/03/02 Javascript
基于angular实现三级联动的生日插件
2017/05/12 Javascript
Javascript实现基本运算器
2017/07/15 Javascript
Vue项目全局配置页面缓存之按需读取缓存的实现详解
2018/08/01 Javascript
layui自定义插件citySelect实现省市区三级联动选择
2019/07/26 Javascript
vue前端和Django后端如何查询一定时间段内的数据
2021/02/28 Vue.js
通过python改变图片特定区域的颜色详解
2019/07/15 Python
python实现五子棋游戏(pygame版)
2020/01/19 Python
python numpy矩阵信息说明,shape,size,dtype
2020/05/22 Python
html5 利用canvas实现超级玛丽简单动画
2013/09/06 HTML / CSS
美国彩妆品牌:Coastal Scents
2017/04/01 全球购物
英国和爱尔兰最大的地毯零售商:Kukoon
2018/12/17 全球购物
开普敦通行证:Cape Town Pass
2019/07/18 全球购物
产品推广策划方案
2014/05/10 职场文书
门市房租房协议书
2014/12/04 职场文书
2015年护士节慰问信
2015/03/23 职场文书
奖励通知
2015/04/22 职场文书
阿甘正传观后感
2015/06/01 职场文书
催款函范本大全
2015/06/24 职场文书
python中Tkinter 窗口之输入框和文本框的实现
2021/04/12 Python
Python软件包安装的三种常见方法
2022/07/07 Python