keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python编程开发之textwrap文本样式处理技巧
Nov 13 Python
Python构建XML树结构的方法示例
Jun 30 Python
django中的HTML控件及参数传递方法
Mar 20 Python
python如何通过twisted实现数据库异步插入
Mar 20 Python
Python SMTP发送邮件遇到的一些问题及解决办法
Oct 24 Python
Pandas的read_csv函数参数分析详解
Jul 02 Python
浅析python中while循环和for循环
Nov 19 Python
Python生态圈图像格式转换问题(推荐)
Dec 02 Python
pandas factorize实现将字符串特征转化为数字特征
Dec 19 Python
利用django创建一个简易的博客网站的示例
Sep 29 Python
Python爬虫之Selenium设置元素等待的方法
Dec 04 Python
python常见的占位符总结及用法
Jul 02 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
对PHP PDO的一些认识小结
2015/01/23 PHP
Linux操作系统安装LAMP环境
2015/06/26 PHP
PHP面向对象学习之parent::关键字
2017/01/18 PHP
Laravel框架实现即点即改功能的方法分析
2019/10/31 PHP
IE8 引入跨站数据获取功能说明
2008/07/22 Javascript
js 解决“options为空或不是对象”
2008/12/22 Javascript
jQuery 位置插件
2008/12/25 Javascript
firefox插件Firebug的使用教程
2010/01/02 Javascript
一个级联菜单代码学习及removeClass与addClass的应用
2013/01/24 Javascript
Node.js生成HttpStatusCode辅助类发布到npm
2013/04/09 Javascript
javascript代码运行不出来执行错误的可能情况整理
2013/10/18 Javascript
angularJS提交表单(form)
2015/02/09 Javascript
浅析node连接数据库(express+mysql)
2015/11/30 Javascript
JavaScript学习笔记--常用的互动方法
2016/12/07 Javascript
vue项目tween方法实现返回顶部的示例代码
2018/03/02 Javascript
Angular中的ng-template及angular 使用ngTemplateOutlet 指令的方法
2018/08/08 Javascript
vue element-ui读取pdf文件的方法
2019/11/26 Javascript
js实现省级联动(数据结构优化)
2020/07/17 Javascript
微信小程序入门之指南针
2020/10/22 Javascript
[01:02:00]DOTA2-DPC中国联赛 正赛 Elephant vs IG BO3 第三场 1月24日
2021/03/11 DOTA
对Python的zip函数妙用,旋转矩阵详解
2018/12/13 Python
python matplotlib库绘制散点图例题解析
2019/08/10 Python
Python:二维列表下标互换方式(矩阵转置)
2019/12/02 Python
python GUI库图形界面开发之PyQt5计数器控件QSpinBox详细使用方法与实例
2020/02/28 Python
基于python requests selenium爬取excel vba过程解析
2020/08/12 Python
python 元组和列表的区别
2020/12/30 Python
俄罗斯最大的在线手表商店:Bestwatch.ru
2020/01/11 全球购物
管理心得体会
2013/12/28 职场文书
领导干部整治奢华浪费之风思想汇报
2014/10/07 职场文书
工作业绩不及格检讨书
2014/10/28 职场文书
2014幼儿园班主任工作总结
2014/12/04 职场文书
2014保险公司内勤工作总结
2014/12/16 职场文书
幼儿园端午节活动总结
2015/05/05 职场文书
刑事撤诉申请书
2015/05/18 职场文书
2015年幼儿园学期工作总结
2015/05/22 职场文书
Win11怎么跳过联网验机 ?Win11跳过联网验机激活教程
2022/04/05 数码科技