keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中取整的几种方法小结
Jan 06 Python
Python标准库sched模块使用指南
Jul 06 Python
PyQt5主窗口动态加载Widget实例代码
Feb 07 Python
Matplotlib 生成不同大小的subplots实例
May 25 Python
Python RabbitMQ消息队列实现rpc
May 30 Python
python实现图片批量压缩程序
Jul 23 Python
python中的常量和变量代码详解
Jul 25 Python
Python中shapefile转换geojson的示例
Jan 03 Python
Python 实现子类获取父类的类成员方法
Jan 11 Python
python opencv判断图像是否为空的实例
Jan 26 Python
Python3中的最大整数和最大浮点数实例
Jul 09 Python
纯python进行矩阵的相乘运算的方法示例
Jul 17 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
php快速url重写更新版[需php 5.30以上]
2010/04/25 PHP
php漏洞之跨网站请求伪造与防止伪造方法
2013/08/15 PHP
php实现mysql数据库操作类分享
2014/02/14 PHP
Zend Framework教程之Zend_Db_Table_Row用法实例分析
2016/03/21 PHP
Windows平台PHP+IECapt实现网页批量截图并创建缩略图功能详解
2019/08/02 PHP
基于JQuery的Pager分页器实现代码
2010/07/17 Javascript
javascript中定义私有方法说明(private method)
2014/01/27 Javascript
javascript读取Xml文件做一个二级联动菜单示例
2014/03/17 Javascript
js判断iframe内的网页是否滚动到底部触发事件
2014/03/18 Javascript
jquery form 隐藏的input 选择
2014/04/29 Javascript
Bootstrap导航中表单简单实现代码
2017/03/06 Javascript
jQuery实现上传图片前预览效果功能
2017/08/03 jQuery
详解动画插件wow.js的使用方法
2017/09/13 Javascript
详解Vue组件实现tips的总结
2017/11/01 Javascript
vue.js实现会动的简历(包含底部导航功能,编辑功能)
2019/04/08 Javascript
element-ui上传一张图片后隐藏上传按钮功能
2019/05/22 Javascript
axios如何取消重复无用的请求详解
2019/12/15 Javascript
angularjs模态框的使用代码实例
2019/12/20 Javascript
微信小程序开发打开另一个小程序的实现方法
2020/05/17 Javascript
[03:36]2014DOTA2 TI小组赛综述 八强诞生进军钥匙球馆
2014/07/15 DOTA
django开发之settings.py中变量的全局引用详解
2017/03/29 Python
用TensorFlow实现戴明回归算法的示例
2018/05/02 Python
Python图像处理之直线和曲线的拟合与绘制【curve_fit()应用】
2018/12/26 Python
详解Python做一个名片管理系统
2019/03/14 Python
对Python中小整数对象池和大整数对象池的使用详解
2019/07/09 Python
python3 tkinter实现添加图片和文本
2019/11/26 Python
解决pycharm上的jupyter notebook端口被占用问题
2019/12/17 Python
详解解决jupyter不能使用pytorch的问题
2021/02/18 Python
英国街头品牌:Bee Inspired Clothing
2018/02/12 全球购物
毕业生机械建模求职信
2013/10/14 职场文书
小学生秋游活动方案
2014/02/23 职场文书
管理提升方案
2014/06/04 职场文书
汉语言文学专业求职信
2014/06/19 职场文书
通报表扬范文
2015/01/17 职场文书
2015年财务部年度工作总结
2015/05/19 职场文书
nginx 防盗链防爬虫配置详解
2021/03/31 Servers