keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python创建只读属性对象的方法(ReadOnlyObject)
Feb 10 Python
Python中统计函数运行耗时的方法
May 05 Python
python常见数制转换实例分析
May 09 Python
对PyTorch torch.stack的实例讲解
Jul 30 Python
解决win64 Python下安装PIL出错问题(图解)
Sep 03 Python
Scrapy框架使用的基本知识
Oct 21 Python
Python 面试中 8 个必考问题
Nov 16 Python
Python实现的爬取百度文库功能示例
Feb 16 Python
基于Python实现剪切板实时监控方法解析
Sep 11 Python
python语言中有算法吗
Jun 16 Python
Python生成并下载文件后端代码实例
Aug 31 Python
python 使用paramiko模块进行封装,远程操作linux主机的示例代码
Dec 03 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
php生成EXCEL的东东
2006/10/09 PHP
php获取用户浏览器版本的方法
2015/01/03 PHP
php提交表单发送邮件的方法
2015/03/20 PHP
利用PHP fsockopen 模拟POST/GET传送数据的方法
2015/09/22 PHP
Yii框架创建cronjob定时任务的方法分析
2017/05/23 PHP
javascript应用:Iframe自适应其加载的内容高度
2007/04/10 Javascript
IE中createElement需要注意的一个问题
2010/07/13 Javascript
jQuery针对各类元素操作基础教程
2014/08/29 Javascript
javascript将url中的参数加密解密代码
2014/11/17 Javascript
js时间戳转为日期格式的方法
2015/12/28 Javascript
JS简单实现tab切换效果的多窗口显示功能
2016/09/07 Javascript
js捕捉键盘事件和按键键值的方法
2016/10/10 Javascript
ASP.NET jquery ajax传递参数的实例
2016/11/02 Javascript
jQuery Validation Engine验证控件调用外部函数验证的方法
2017/01/18 Javascript
详解微信小程序 通过控制CSS实现view隐藏与显示
2017/05/24 Javascript
利用JavaScript如何查询某个值是否数组内
2017/07/30 Javascript
通过npm或yarn自动生成vue组件的方法示例
2019/02/12 Javascript
vue cli使用融云实现聊天功能的实例代码
2019/04/19 Javascript
[02:41]DOTA2英雄基础教程 冥魂大帝
2014/01/16 DOTA
[01:35]辉夜杯战队访谈宣传片—LGD
2015/12/25 DOTA
浅谈python中scipy.misc.logsumexp函数的运用场景
2016/06/23 Python
python删除服务器文件代码示例
2018/02/09 Python
对pandas里的loc并列条件索引的实例讲解
2018/11/15 Python
python3+django2开发一个简单的人员管理系统过程详解
2019/07/23 Python
django使用F方法更新一个对象多个对象字段的实现
2020/03/28 Python
如何在python中执行另一个py文件
2020/04/30 Python
Python中无限循环需要什么条件
2020/05/27 Python
keras读取h5文件load_weights、load代码操作
2020/06/12 Python
学习Python爬虫的几点建议
2020/08/05 Python
HTML5 实战PHP之Web页面表单设计
2011/10/09 HTML / CSS
Avène雅漾美国官方网站:敏感肌肤护理专家
2016/10/24 全球购物
班级活动策划书
2014/02/06 职场文书
电子工程专业毕业生求职信
2014/03/14 职场文书
现场活动策划方案
2014/08/22 职场文书
群众路线领导对照材料
2014/08/23 职场文书
Spring整合Mybatis的全过程
2021/06/28 Java/Android