keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python操作MySQL数据库具体方法
Oct 28 Python
python实现ftp客户端示例分享
Feb 17 Python
编写Python脚本来获取Google搜索结果的示例
May 04 Python
python正则中最短匹配实现代码
Jan 16 Python
Python装饰器的执行过程实例分析
Jun 04 Python
Python实现网站表单提交和模板
Jan 15 Python
使用python批量修改文件名的方法(视频合并时)
Mar 24 Python
使用Python制作一个打字训练小工具
Oct 01 Python
Python 实现自动导入缺失的库
Oct 29 Python
python 多维高斯分布数据生成方式
Dec 09 Python
服务器端jupyter notebook映射到本地浏览器的操作
Apr 14 Python
pycharm 2018 激活码及破解补丁激活方式
Sep 21 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
用PHP调用数据库的存贮过程!
2006/10/09 PHP
php5.3 废弃函数小结
2010/05/16 PHP
PHP数字前补0的自带函数sprintf 和number_format的用法(详解)
2017/02/06 PHP
thinkPHP实现签到功能的方法
2017/03/15 PHP
初试jQuery EasyUI 使用介绍
2010/04/01 Javascript
jQuery效果 slideToggle() 方法(在隐藏和显示之间切换)
2011/06/28 Javascript
基于jQuery的星级评分插件
2011/08/12 Javascript
探讨jQuery的ajax使用场景(c#)
2013/12/03 Javascript
javaScript使用EL表达式的几种方式
2014/05/27 Javascript
JS根据生日算年龄的方法
2015/05/05 Javascript
学习JavaScript设计模式之责任链模式
2016/01/18 Javascript
js倒计时简单实现代码
2016/08/11 Javascript
Vuejs仿网易云音乐实现听歌及搜索功能
2017/03/30 Javascript
socket.io学习教程之深入学习篇(三)
2017/04/29 Javascript
Node.js中sequelize时区的配置方法
2017/12/10 Javascript
JavaScript中严格判断NaN的方法
2018/02/16 Javascript
vue单页开发父子组件传值思路详解
2018/05/18 Javascript
详解Angular-ui-BootStrap组件的解释以及使用
2018/07/13 Javascript
Vuex 使用 v-model 配合 state的方法
2018/11/13 Javascript
Vue.js@2.6.10更新内置错误处机制Fundebug同步支持相应错误监控
2019/05/13 Javascript
原生js实现html手机端城市列表索引选择城市
2020/06/24 Javascript
python检测服务器是否正常
2014/02/16 Python
python实现二分查找算法
2017/09/21 Python
Python跨文件全局变量的实现方法示例
2017/12/10 Python
python利用跳板机ssh远程连接redis的方法
2019/02/19 Python
Python使用pyserial进行串口通信的实例
2019/07/02 Python
Flask中sqlalchemy模块的实例用法
2020/08/02 Python
快速一键生成Python爬虫请求头
2021/03/04 Python
实例教程 纯CSS3打造非常炫的加载动画效果
2014/11/05 HTML / CSS
阿拉伯时尚购物网站:Nisnass
2021/02/07 全球购物
单位门卫岗位职责
2013/12/20 职场文书
社区学习雷锋活动总结
2014/04/25 职场文书
服务承诺书格式
2014/05/21 职场文书
初一英语教学反思
2016/02/15 职场文书
Python图片验证码降噪和8邻域降噪
2021/08/30 Python
了解Kubernetes中的Service和Endpoint
2022/04/01 Servers