keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中logging模块的用法实例
Sep 29 Python
python实现ipsec开权限实例
Nov 11 Python
python获取标准北京时间的方法
Mar 24 Python
以一段代码为实例快速入门Python2.7
Mar 31 Python
python类和继承用法实例
Jul 07 Python
请不要重复犯我在学习Python和Linux系统上的错误
Dec 12 Python
python按综合、销量排序抓取100页的淘宝商品列表信息
Feb 24 Python
python用户管理系统
Mar 13 Python
Python实现的计算器功能示例
Apr 26 Python
python多线程同步实例教程
Aug 11 Python
python判断无向图环是否存在的示例
Nov 22 Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
全国FM电台频率大全 - 22 重庆市
2020/03/11 无线电
注意:php5.4删除了session_unregister函数
2013/08/05 PHP
PHP解析html类库simple_html_dom的转码bug
2014/05/22 PHP
PHP+MySQL实现输入页码跳转到指定页面功能示例
2018/06/01 PHP
JQuery实现用户名无刷新验证的小例子
2013/03/22 Javascript
jquery获得页面元素的坐标值实现思路及代码
2013/04/15 Javascript
基于jquery编写的横向自适应幻灯片切换特效的实例代码
2013/08/06 Javascript
解决checkbox的attr(checked)一直为undefined问题
2014/06/16 Javascript
使用AmplifyJS组件配合JavaScript进行编程的指南
2015/07/28 Javascript
jquery专业的导航菜单特效代码分享
2015/08/29 Javascript
JS实现可拖曳、可关闭的弹窗效果
2015/09/26 Javascript
js监听键盘事件的方法_原生和jquery的区别详解
2016/10/10 Javascript
详解Vue-cli 创建的项目如何跨域请求
2017/05/18 Javascript
基于jQuery对象和DOM对象和字符串之间的转化实例
2017/08/08 jQuery
element-ui 关于获取select 的label值方法
2018/08/24 Javascript
vue 实现边输入边搜索功能的实例讲解
2018/09/16 Javascript
angularJs使用ng-repeat遍历后选中某一个的方法
2018/09/30 Javascript
vue实现鼠标移入移出事件代码实例
2019/03/27 Javascript
微信小程序之数据绑定原理解析
2019/08/14 Javascript
详解微信小程序「渲染层网络层错误」的解决方法
2021/01/06 Javascript
[04:39]显微镜下的DOTA2第十三期—Pis卡尔个人秀
2014/04/04 DOTA
python django 增删改查操作 数据库Mysql
2017/07/27 Python
修复CentOS7升级Python到3.6版本后yum不能正确使用的解决方法
2018/01/26 Python
在Qt中正确的设置窗体的背景图片的几种方法总结
2019/06/19 Python
Linux下通过python获取本机ip方法示例
2019/09/06 Python
Python调用接口合并Excel表代码实例
2020/03/31 Python
python模拟哔哩哔哩滑块登入验证的实现
2020/04/24 Python
Python2手动安装更新pip过程实例解析
2020/07/16 Python
浅谈html5增强的页面元素
2016/06/14 HTML / CSS
html5简介及新增功能介绍
2020/05/18 HTML / CSS
我的中国梦演讲稿500字
2014/08/19 职场文书
五年级上册复习计划
2015/01/19 职场文书
高中教师个人工作总结
2015/02/10 职场文书
文言文辞职信
2015/02/28 职场文书
行政复议答复书
2015/07/01 职场文书
2019年销售部季度工作计划3篇
2019/10/09 职场文书