keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python下使用Txt2Html实现网页过滤代理的教程
Apr 11 Python
在Python的Flask框架中使用模版的入门教程
Apr 20 Python
Python合并字符串的3种方法
May 21 Python
举例讲解Django中数据模型访问外键值的方法
Jul 21 Python
Python中表达式x += y和x = x+y 的区别详解
Jun 20 Python
Python抓取聚划算商品分析页面获取商品信息并以XML格式保存到本地
Feb 23 Python
Python 读取指定文件夹下的所有图像方法
Apr 27 Python
python调用opencv实现猫脸检测功能
Jan 15 Python
利用pyinstaller打包exe文件的基本教程
May 02 Python
如何用Python破解wifi密码过程详解
Jul 12 Python
python创建属于自己的单词词库 便于背单词
Jul 30 Python
只用40行Python代码就能写出pdf转word小工具
May 31 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
Phpbean路由转发的php代码
2008/01/10 PHP
PHP DataGrid 实现代码
2009/08/12 PHP
基于MySQL到MongoDB简易对照表的详解
2013/06/03 PHP
php ctype函数中文翻译和示例
2014/03/21 PHP
PHP登录环节防止sql注入的方法浅析
2014/06/30 PHP
PHP使用Pthread实现的多线程操作实例
2015/11/14 PHP
如何实现动态删除javascript函数
2007/05/27 Javascript
jQuery EasyUI API 中文文档 - Draggable 可拖拽
2011/09/29 Javascript
JS字符串的切分用法实例
2016/02/22 Javascript
js利用clipboardData实现截屏粘贴功能
2016/10/12 Javascript
vue2.0获取自定义属性的值
2017/03/28 Javascript
vue实现图片滚动的示例代码(类似走马灯效果)
2018/03/03 Javascript
vue项目中使用ueditor的实例讲解
2018/03/05 Javascript
vue.js+element-ui动态配置菜单的实例
2018/09/07 Javascript
Python 文件重命名工具代码
2009/07/26 Python
python实现在windows下操作word的方法
2015/04/28 Python
Python3实现爬取指定百度贴吧页面并保存页面数据生成本地文档的方法
2018/04/22 Python
python高级特性和高阶函数及使用详解
2018/10/17 Python
python画图--输出指定像素点的颜色值方法
2019/07/03 Python
python选取特定列 pandas iloc,loc,icol的使用详解(列切片及行切片)
2019/08/06 Python
用Pelican搭建一个极简静态博客系统过程解析
2019/08/22 Python
Python Print实现在输出中插入变量的例子
2019/12/25 Python
jupyter notebook 调用环境中的Keras或者pytorch教程
2020/04/14 Python
解决pycharm导入本地py文件时,模块下方出现红色波浪线的问题
2020/06/01 Python
HTML5样式控制示例代码
2013/11/27 HTML / CSS
波兰在线儿童和婴儿用品零售商:pinkorblue
2019/06/29 全球购物
工程力学专业毕业生求职信
2013/10/06 职场文书
2014年学习厉行节约反对浪费思想汇报
2014/09/10 职场文书
公司辞职信模板
2015/05/13 职场文书
导游词之西安骊山
2019/12/03 职场文书
2020年个人安全保证书参考模板
2020/01/08 职场文书
pytorch 实现多个Dataloader同时训练
2021/05/29 Python
Golang jwt身份认证
2022/04/20 Golang
总结三种用 Python 作为小程序后端的方式
2022/05/02 Python
JS前端可视化canvas动画原理及其推导实现
2022/08/05 Javascript
CSS list-style-type属性使用方法
2023/05/21 HTML / CSS