keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解python多进程编程
Jun 12 Python
django项目运行因中文而乱码报错的几种情况解决
Nov 07 Python
python 请求服务器的实现代码(http请求和https请求)
May 25 Python
python实现傅里叶级数展开的实现
Jul 21 Python
Python批处理更改文件名os.rename的方法
Oct 26 Python
python实现点击按钮修改数据的方法
Jul 17 Python
Django models.py应用实现过程详解
Jul 29 Python
Python实现微信中找回好友、群聊用户撤回的消息功能示例
Aug 23 Python
Python搭建HTTP服务过程图解
Dec 14 Python
Keras 实现加载预训练模型并冻结网络的层
Jun 15 Python
python实现mask矩阵示例(根据列表所给元素)
Jul 30 Python
python的scipy.stats模块中正态分布常用函数总结
Feb 19 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
通过PHP修改Linux或Unix口令的方法分享
2012/01/30 PHP
PHP中Memcache操作类及用法实例
2014/12/12 PHP
php站内搜索关键词变亮的实现方法
2014/12/30 PHP
JS之小练习代码
2008/10/12 Javascript
重载toString实现JS HashMap分析
2011/03/13 Javascript
基于jQuery的动态表格插件
2011/03/28 Javascript
JQuery1.6 使用方法三
2011/11/23 Javascript
Jquery Ajax解析XML数据(同步及异步调用)简单实例
2014/02/12 Javascript
JavaScript 基本概念
2015/01/20 Javascript
微信中一些常用的js方法汇总
2015/03/12 Javascript
jQuery简单实现遍历数组的方法
2015/04/14 Javascript
AngularJS手动表单验证
2016/02/01 Javascript
Bootstrap Search Suggest使用例子
2016/12/21 Javascript
angularjs实现过滤并替换关键字小功能
2017/09/19 Javascript
layui2.0使用table+laypage实现真分页
2019/07/27 Javascript
基于layui轮播图满屏是高度自适应的解决方法
2019/09/16 Javascript
js判断在哪个浏览器打开项目的方法
2020/01/21 Javascript
vue 防止页面加载时看到花括号的解决操作
2020/11/09 Javascript
解决pycharm 远程调试 上传 helpers 卡住的问题
2019/06/27 Python
python实现微信自动回复及批量添加好友功能
2019/07/03 Python
Python3读写Excel文件(使用xlrd,xlsxwriter,openpyxl3种方式读写实例与优劣)
2020/02/13 Python
Python sqlalchemy时间戳及密码管理实现代码详解
2020/08/01 Python
Python编写memcached启动脚本代码实例
2020/08/14 Python
python实现单机五子棋
2020/08/28 Python
美国著名的户外用品品牌:L.L.Bean
2018/01/05 全球购物
Big Green Smile德国网上商店:提供各种天然产品
2018/05/23 全球购物
Set里的元素是不能重复的,那么用什么方法来区分重复与否呢?
2016/08/18 面试题
留学生如何写好自荐信
2013/12/27 职场文书
学生宿舍管理制度
2014/01/30 职场文书
仓管员岗位责任制
2014/02/19 职场文书
农村婚礼主持词
2014/03/13 职场文书
大学优秀班主任事迹材料
2014/05/02 职场文书
技能比武方案
2014/05/21 职场文书
涉密人员保密承诺书
2014/05/28 职场文书
护士个人年终总结
2015/02/13 职场文书