keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈Python NLP入门教程
Dec 25 Python
TensorFlow实现简单卷积神经网络
May 24 Python
钉钉群自定义机器人消息Python封装的实例
Feb 20 Python
PyQt弹出式对话框的常用方法及标准按钮类型
Feb 27 Python
Python流行ORM框架sqlalchemy安装与使用教程
Jun 04 Python
python 弹窗提示警告框MessageBox的实例
Jun 18 Python
Python跳出多重循环的方法示例
Jul 03 Python
python ftplib模块使用代码实例
Dec 31 Python
Python如何实现小程序 无限求和平均
Feb 18 Python
Keras 实现加载预训练模型并冻结网络的层
Jun 15 Python
python 基于opencv实现高斯平滑
Dec 18 Python
Pycharm 设置默认解释器路径和编码格式的操作
Feb 05 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
PHP SEO优化之URL优化方法
2011/04/21 PHP
PHP7新特性之抽象语法树(AST)带来的变化详解
2018/07/17 PHP
JQuery select控件的相关操作实现代码
2012/09/14 Javascript
JS中的substring和substr函数的区别说明
2013/05/07 Javascript
js获取当前路径的简单示例代码
2014/01/08 Javascript
用json方式实现在 js 中建立一个map
2014/05/02 Javascript
jQuery实现下拉框左右移动(全部移动,已选移动)
2016/04/15 Javascript
nodejs中向HTTP响应传送进程的输出
2017/03/19 NodeJs
vue-cli结合Element-ui基于cropper.js封装vue实现图片裁剪组件功能
2018/03/01 Javascript
Vue组件教程之Toast(Vue.extend 方式)详解
2019/01/27 Javascript
vue实现标签云效果的方法详解
2019/08/28 Javascript
vue随机验证码组件的封装实现
2020/02/19 Javascript
javascript设计模式 ? 策略模式原理与用法实例分析
2020/04/21 Javascript
python面向对象_详谈类的继承与方法的重载
2017/06/07 Python
神经网络(BP)算法Python实现及应用
2018/04/16 Python
Flask框架Flask-Principal基本用法实例分析
2018/07/23 Python
Python基本socket通信控制操作示例
2019/01/30 Python
Python OpenCV利用笔记本摄像头实现人脸检测
2020/08/20 Python
pandas的resample重采样的使用
2020/04/24 Python
Pycharm的Available Packages为空的解决方法
2020/09/18 Python
M1芯片安装python3.9.1的实现
2021/02/02 Python
js实现移动端H5页面手指滑动刻度尺功能
2017/11/16 HTML / CSS
韩国休闲女装品牌网站:ANAIS
2016/08/24 全球购物
First Aid Beauty官网:FAB急救面霜
2018/05/24 全球购物
学前教育教师求职自荐信
2013/09/22 职场文书
表彰先进集体通报
2014/01/12 职场文书
国庆节文艺活动方案
2014/02/03 职场文书
中学生评语大全
2014/04/18 职场文书
公司董事长岗位职责
2014/06/08 职场文书
法制宣传标语集锦
2014/06/25 职场文书
小学生关于梦想的演讲稿
2014/08/22 职场文书
会计专业自荐信范文
2015/03/05 职场文书
任命书格式模板
2015/09/22 职场文书
JavaScript继承的三种方法实例
2021/05/12 Javascript
5种 JavaScript 方式实现数组扁平化
2021/10/05 Javascript
试用1103暨1103、1101同门大比武 [ DAIWEI ]
2022/04/05 无线电