keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
更改Python命令行交互提示符的方法
Jan 14 Python
使用Python读写文本文件及编写简单的文本编辑器
Mar 11 Python
深入浅析python中的多进程、多线程、协程
Jun 22 Python
python解决Fedora解压zip时中文乱码的方法
Sep 18 Python
Python爬取网页中的图片(搜狗图片)详解
Mar 23 Python
Python 将RGB图像转换为Pytho灰度图像的实例
Nov 14 Python
浅谈python中字典append 到list 后值的改变问题
May 04 Python
Python通用循环的构造方法实例分析
Dec 19 Python
关于阿里云oss获取sts凭证 app直传 python的实例
Aug 20 Python
解决django后台管理界面添加中文内容乱码问题
Nov 15 Python
matlab 计算灰度图像的一阶矩,二阶矩,三阶矩实例
Apr 22 Python
在tensorflow下利用plt画论文中loss,acc等曲线图实例
Jun 15 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
php下通过POST还是GET来传值
2008/06/05 PHP
解析php入库和出库
2013/06/25 PHP
php链表用法实例分析
2015/07/09 PHP
PHP简单操作MongoDB的方法(安装及增删改查)
2016/05/26 PHP
php_pdo 预处理语句详解
2016/11/21 PHP
PHP实现二维数组根据key进行排序的方法
2016/12/30 PHP
thinkphp5框架路由原理与用法详解
2020/02/11 PHP
jQuery Ajax之$.get()方法和$.post()方法
2009/10/12 Javascript
js中的onchange和onpropertychange (onchange无效的解决方法)
2014/03/08 Javascript
解决jquery中动态新增的元素节点无法触发事件问题的两种方法
2015/10/30 Javascript
jQueryUI Sortable 应用Demo(分享)
2017/09/07 jQuery
JS闭包的几种常见形式实例详解
2017/09/16 Javascript
JS使用tofixed与round处理数据四舍五入的区别
2017/10/25 Javascript
vue地区选择组件教程详解
2018/05/04 Javascript
JS控制GIF图片的停止与显示
2019/10/24 Javascript
使用 Github Actions 自动部署 Angular 应用到 Github Pages的方法
2020/07/20 Javascript
[01:22:42]2014 DOTA2华西杯精英邀请赛 5 24 DK VS LGD
2014/05/26 DOTA
Python读取图片EXIF信息类库介绍和使用实例
2014/07/10 Python
探究Python的Tornado框架对子域名和泛域名的支持
2015/05/02 Python
Python线性方程组求解运算示例
2018/01/17 Python
python使用多进程的实例详解
2018/09/19 Python
django2.0扩展用户字段示例
2019/02/13 Python
解析html5 canvas实现背景鼠标连线动态效果代码
2019/06/17 HTML / CSS
5分钟实现Canvas鼠标跟随动画背景
2019/11/18 HTML / CSS
德国旅游网站:weg.de
2018/06/03 全球购物
Stokke美国官方网店:高级儿童家具、推车、汽车座椅和配件
2020/06/06 全球购物
运动会开幕式邀请函
2014/02/03 职场文书
《再见了,亲人》教学反思
2014/02/26 职场文书
文化产业实施方案
2014/06/07 职场文书
教师自我剖析材料
2014/09/29 职场文书
毕业感言怎么写
2015/07/31 职场文书
六一儿童节致辞稿(3篇)
2019/07/11 职场文书
2019送给家人们的中秋节祝福语
2019/08/15 职场文书
广告文案的撰写技巧(实用干货)
2019/08/23 职场文书
MySQL 开窗函数
2022/02/15 MySQL
请求模块urllib之PYTHON爬虫的基本使用
2022/04/08 Python