keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实例一个类背后发生了什么
Feb 09 Python
利用python编写一个图片主色转换的脚本
Dec 07 Python
Python给定一个句子倒序输出单词以及字母的方法
Dec 20 Python
python sort、sort_index方法代码实例
Mar 28 Python
详解Python:面向对象编程
Apr 10 Python
Python编程学习之如何判断3个数的大小
Aug 07 Python
PyCharm取消波浪线、下划线和中划线的实现
Mar 03 Python
python 连续不等式语法糖实例
Apr 15 Python
python使用matplotlib:subplot绘制多个子图的示例
Sep 24 Python
python多线程和多进程关系详解
Dec 14 Python
Python 多进程原理及实现
Dec 21 Python
python的scipy.stats模块中正态分布常用函数总结
Feb 19 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
PHP CURL 多线程操作代码实例
2015/05/13 PHP
ThinkPHP中session函数详解
2016/09/14 PHP
PHP实现表单提交数据的验证处理功能【防SQL注入和XSS攻击等】
2017/07/21 PHP
PHP设计模式之单例模式原理与实现方法分析
2018/04/25 PHP
PHP下载文件函数与用法示例
2019/09/27 PHP
一段利用WSH修改和查看IP配置的代码
2008/05/11 Javascript
JavaScript prototype对象的属性说明
2010/03/13 Javascript
ie6下png图片背景不透明的解决办法使用js实现
2013/01/11 Javascript
Js-$.extend扩展方法使方法参数更灵活
2013/01/15 Javascript
javascript定义变量时带var与不带var的区别分析
2015/01/12 Javascript
jQuery幻灯片带缩略图轮播效果代码分享
2015/08/17 Javascript
EasyUI在表单提交之前进行验证的实例代码
2016/06/24 Javascript
使用Bootrap和Vue实现仿百度搜索功能
2017/10/26 Javascript
React props和state属性的具体使用方法
2018/04/12 Javascript
vue-router 源码之实现一个简单的 vue-router
2018/07/02 Javascript
原生js实现form表单序列化的方法
2018/08/02 Javascript
Angular2 自定义表单验证器的实现方法
2018/12/14 Javascript
JavaScript中十种一步拷贝数组的方法实例详解
2019/04/22 Javascript
Vue组件跨层级获取组件操作
2020/07/27 Javascript
[01:09]2014DOTA2国际邀请赛 TI4西雅图DOTA2 中国美女coser加油助威
2014/07/20 DOTA
[50:45]2018DOTA2亚洲邀请赛 4.6 淘汰赛 VP vs TNC 第一场
2018/04/10 DOTA
Python中的条件判断语句基础学习教程
2016/02/07 Python
Tensorflow之Saver的用法详解
2018/04/23 Python
对Python发送带header的http请求方法详解
2019/01/02 Python
Django实现文章详情页面跳转代码实例
2020/09/16 Python
Django-silk性能测试工具安装及使用解析
2020/11/28 Python
IWOOT美国:新奇的小玩意
2018/04/27 全球购物
River Island美国官网:英国高街时尚品牌
2018/09/04 全球购物
大学生自我评价怎样写好
2013/10/23 职场文书
资料员岗位职责
2013/11/17 职场文书
电子商务专员岗位职责
2013/12/11 职场文书
奥巴马上海演讲稿
2014/09/10 职场文书
2014年物业管理工作总结
2014/11/21 职场文书
学校运动会感想
2015/08/10 职场文书
小学四年级班主任工作经验交流材料
2015/11/02 职场文书
python中super()函数的理解与基本使用
2021/08/30 Python