keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的字典和列表的使用中一些需要注意的地方
Apr 24 Python
Python判断字符串与大小写转换
Jun 08 Python
Python导出DBF文件到Excel的方法
Jul 25 Python
mysql 之通过配置文件链接数据库
Aug 12 Python
Python从单元素字典中获取key和value的实例
Dec 31 Python
twilio python自动拨打电话,播放自定义mp3音频的方法
Aug 08 Python
利用Python小工具实现3秒钟将视频转换为音频
Oct 29 Python
Django实现分页显示效果
Oct 31 Python
Python模块的制作方法实例分析
Dec 21 Python
如何通过Django使用本地css/js文件
Jan 20 Python
keras模型保存为tensorflow的二进制模型方式
May 25 Python
python如何查看网页代码
Jun 07 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
php采集文章中的图片获取替换到本地(实现代码)
2013/07/08 PHP
修改apache配置文件去除thinkphp url中的index.php
2014/01/17 PHP
thinkphp中多表查询中防止数据重复的sql语句(必看)
2016/09/22 PHP
Symfony2针对输入时间进行查询的方法分析
2017/06/28 PHP
PHP+MySQL实现模糊查询员工信息功能示例
2018/06/01 PHP
一实用的实现table排序的Javascript类库
2007/09/12 Javascript
javascript  Error 对象 错误处理
2008/05/18 Javascript
ECMAScript6中Set/WeakSet详解
2015/06/12 Javascript
jQuery实现的进度条效果
2015/07/15 Javascript
解决jQuery uploadify在非IE核心浏览器下无法上传
2015/08/05 Javascript
基于javascript代码实现通过点击图片显示原图片
2015/11/29 Javascript
以JavaScript来实现WordPress中的二级导航菜单的方法
2015/12/14 Javascript
在localStorage中存储对象数组并读取的方法
2016/09/24 Javascript
Angular和百度地图的结合实例代码
2016/10/19 Javascript
Vue制作Todo List网页
2017/04/26 Javascript
JS按钮闪烁功能的实现代码
2017/07/21 Javascript
vue+vuex+axios+echarts画一个动态更新的中国地图的方法
2017/12/19 Javascript
JS实现的汉字与Unicode码相互转化功能分析
2018/05/25 Javascript
微信小程序实现弹框效果
2020/05/26 Javascript
JavaScript中window和document用法详解
2020/07/28 Javascript
详谈Python高阶函数与函数装饰器(推荐)
2017/09/30 Python
Python实现多级目录压缩与解压文件的方法
2018/09/01 Python
使用PIL(Python-Imaging)反转图像的颜色方法
2019/01/24 Python
pandas计算最大连续间隔的方法
2019/07/04 Python
python matplotlib库绘制散点图例题解析
2019/08/10 Python
英国领先的男装设计师服装购物网站:Mainline Menswear
2018/02/04 全球购物
美国在线乐器和设备商店:Musician’s Friend
2018/07/06 全球购物
Boden英国官网:英国知名原创时装品牌
2018/11/06 全球购物
中国专业的音频分享平台:喜马拉雅
2019/05/24 全球购物
房地产销售员的自我评价分享
2013/12/04 职场文书
给医务人员表扬信
2014/01/12 职场文书
对公司合理化的建议书
2014/03/12 职场文书
2014年化验员工作总结
2014/11/18 职场文书
汽车转让协议书范本
2014/12/07 职场文书
2016大学生就业指导课心得体会
2016/01/15 职场文书
Oracle11g r2 卸载干净重装的详细教程(亲测有效已重装过)
2021/06/04 Oracle