keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现网页链接提取的方法分享
Feb 25 Python
Python内置函数 next的具体使用方法
Nov 24 Python
python 判断网络连通的实现方法
Apr 22 Python
解决pandas无法在pycharm中使用plot()方法显示图像的问题
May 24 Python
Django 表单模型选择框如何使用分组
May 16 Python
Python调用C语言的实现
Jul 26 Python
pytorch自定义初始化权重的方法
Aug 17 Python
python性能测量工具cProfile使用解析
Sep 26 Python
python用requests实现http请求代码实例
Oct 31 Python
一个非常简单好用的Python图形界面库(PysimpleGUI)
Dec 28 Python
Python Matplotlib绘制等高线图与渐变色扇形图
Apr 14 Python
Python创建SQL数据库流程逐步讲解
Sep 23 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
ThinkPHP3.0略缩图不能保存到子目录的解决方法
2012/09/30 PHP
php微信公众号开发之微信企业付款给个人
2018/10/04 PHP
php微信公众号开发之关键词回复
2018/10/20 PHP
FormValidate 表单验证功能代码更新并提供下载
2008/08/23 Javascript
动态加载图片路径 保持JavaScript控件的相对独立性
2010/09/03 Javascript
js特殊字符过滤的示例代码
2014/03/05 Javascript
jQuery实现表格展开与折叠的方法
2015/05/04 Javascript
在JS中操作时间之getUTCMilliseconds()方法的使用
2015/06/10 Javascript
详解JavaScript中数组和字符串的lastIndexOf()方法使用
2016/03/13 Javascript
功能强大的Bootstrap组件(结合js)
2016/08/03 Javascript
jQuery tagsinput在h5邮件客户端中应用详解
2016/09/26 Javascript
js实现图片左右滚动效果
2017/02/27 Javascript
CentOS环境中MySQL修改root密码方法
2018/01/07 Javascript
Angular学习教程之RouterLink花式跳转
2018/05/03 Javascript
Intellij IDEA搭建vue-cli项目的方法步骤
2018/10/20 Javascript
基于JavaScript实现每日签到打卡轨迹功能
2018/11/29 Javascript
微信小程序列表时间戳转换实现过程解析
2019/10/12 Javascript
Vue生命周期activated之返回上一页不重新请求数据操作
2020/07/26 Javascript
vue 解决IOS10低版本白屏的问题
2020/11/17 Javascript
如何将python中的List转化成dictionary
2016/08/15 Python
用Python将动态GIF图片倒放播放的方法
2016/11/02 Python
python利用Guetzli批量压缩图片
2017/03/23 Python
详解Python import方法引入模块的实例
2017/08/02 Python
Python解决抛小球问题 求小球下落经历的距离之和示例
2018/02/01 Python
python网络爬虫学习笔记(1)
2018/04/09 Python
python画图——实现在图上标注上具体数值的方法
2019/07/08 Python
音频处理 windows10下python三方库librosa安装教程
2020/06/20 Python
HTML4和HTML5之间除了相似以外的10个主要不同
2012/12/13 HTML / CSS
Linux不知道文件后缀名怎么判断文件类型
2014/08/21 面试题
应聘自荐书
2013/10/08 职场文书
机电一体化毕业生求职信
2013/11/02 职场文书
课改先进个人汇报材料
2014/01/26 职场文书
幼儿园课题方案
2014/06/09 职场文书
2014年预备党员群众路线教育实践活动对照检查材料思想汇报
2014/10/02 职场文书
Mac M1安装mnmp (Mac+Nginx+MySQL+PHP) 开发环境
2021/03/29 PHP
PostgreSQL数据库去除重复数据和运算符的基本查询操作
2022/04/12 PostgreSQL