使用keras实现孪生网络中的权值共享教程


Posted in Python onJune 11, 2020

首先声明,这里的权值共享指的不是CNN原理中的共享权值,而是如何在构建类似于Siamese Network这样的多分支网络,且分支结构相同时,如何使用keras使分支的权重共享。

Functional API

为达到上述的目的,建议使用keras中的Functional API,当然Sequential 类型的模型也可以使用,本篇博客将主要以Functional API为例讲述。

keras的多分支权值共享功能实现,官方文档介绍

上面是官方的链接,本篇博客也是基于上述官方文档,实现的此功能。(插一句,keras虽然有中文文档,但中文文档已停更,且中文文档某些函数介绍不全,建议直接看英文官方文档)

不共享参数的模型

以MatchNet网络结构为例子,为方便显示,将卷积模块个数减为2个。首先是展示不共享参数的模型,以便观看完整的网络结构。

整体的网络结构如下所示:

代码包含两部分,第一部分定义了两个函数,FeatureNetwork()生成特征提取网络,ClassiFilerNet()生成决策网络或称度量网络。网络结构的可视化在博客末尾。在ClassiFilerNet()函数中,可以看到调用了两次FeatureNetwork()函数,keras.models.Model也被使用的两次,因此生成的input1和input2是两个完全独立的模型分支,参数是不共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ---------------------函数功能区-------------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""
  input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
  input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
  for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
    layer.name = layer.name + str("_2")
  inp1 = input1.input
  inp2 = input2.input
  merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

# ---------------------主调区-------------------------
matchnet = ClassiFilerNet()
matchnet.summary() # 打印网络结构
plot_model(matchnet, to_file='G:/csdn攻略/picture/model.png') # 网络结构输出成png图片

共享参数的模型

FeatureNetwork()的功能和上面的功能相同,为方便选择,在ClassiFilerNet()函数中加入了判断是否使用共享参数模型功能,令reuse=True,便使用的是共享参数的模型。

关键地方就在,只使用的一次Model,也就是说只创建了一次模型,虽然输入了两个输入,但其实使用的是同一个模型,因此权重共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ----------------函数功能区-----------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
  # models = Activation('relu')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(reuse=False): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""

  if reuse:
    inp = Input(shape=(28, 28, 1), name='FeatureNet_ImageInput')
    models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
    models = Activation('relu')(models)
    models = MaxPool2D(pool_size=(3, 3))(models)

    models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
    # models = Activation('relu')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Flatten()(models)
    models = Dense(512)(models)
    models = Activation('relu')(models)
    model = Model(inputs=inp, outputs=models)

    inp1 = Input(shape=(28, 28, 1)) # 创建输入
    inp2 = Input(shape=(28, 28, 1)) # 创建输入2
    model_1 = model(inp1) # 孪生网络中的一个特征提取分支
    model_2 = model(inp2) # 孪生网络中的另一个特征提取分支
    merge_layers = concatenate([model_1, model_2]) # 进行融合,使用的是默认的sum,即简单的相加

  else:
    input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
    input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
    for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
      layer.name = layer.name + str("_2")
    inp1 = input1.input
    inp2 = input2.input
    merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

如何看是否真的是权值共享呢?直接对比特征提取部分的网络参数个数!

不共享参数模型的参数数量:

使用keras实现孪生网络中的权值共享教程

共享参数模型的参数总量

使用keras实现孪生网络中的权值共享教程

共享参数模型中的特征提取部分的参数量为:

使用keras实现孪生网络中的权值共享教程

由于截图限制,不共享参数模型的特征提取网络参数数量不再展示。其实经过计算,特征提取网络部分的参数数量,不共享参数模型是共享参数的两倍。两个网络总参数量的差值就是,共享模型中,特征提取部分的参数的量

网络结构可视化

不共享权重的网络结构

使用keras实现孪生网络中的权值共享教程

共享参数的网络结构,其中model_1代表的就是特征提取部分。

使用keras实现孪生网络中的权值共享教程

以上这篇使用keras实现孪生网络中的权值共享教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3安装Pymongo详细步骤
May 26 Python
Python3编程实现获取阿里云ECS实例及监控的方法
Aug 18 Python
Python实现PS滤镜功能之波浪特效示例
Jan 26 Python
Python及Django框架生成二维码的方法分析
Jan 31 Python
Tensorflow实现卷积神经网络用于人脸关键点识别
Mar 05 Python
在自动化中用python实现键盘操作的方法详解
Jul 19 Python
pytorch 固定部分参数训练的方法
Aug 17 Python
用Python去除图像的黑色或白色背景实例
Dec 12 Python
python实现图片横向和纵向拼接
Mar 05 Python
python exit出错原因整理
Aug 31 Python
详解Python3 定义一个跨越多行的字符串的多种方法
Sep 06 Python
pytorch 实现在测试的时候启用dropout
May 27 Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
keras load model时出现Missing Layer错误的解决方式
Jun 11 #Python
Pyinstaller加密打包应用的示例代码
Jun 11 #Python
解决keras加入lambda层时shape的问题
Jun 11 #Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 #Python
You might like
php 伪造本地文件包含漏洞的代码
2011/11/03 PHP
AES加解密在php接口请求过程中的应用示例
2016/10/26 PHP
javascript 节点遍历函数
2010/03/28 Javascript
浏览器解析js生成的html出现样式问题的解决方法
2012/04/16 Javascript
JQuery DataTable删除行后的页面更新利用Ajax解决
2013/05/17 Javascript
javascript结合Canvas 实现简易的圆形时钟
2015/03/11 Javascript
代码分析jQuery四种静态方法使用
2015/07/23 Javascript
javascript实现3D变换的立体圆圈实例
2015/08/06 Javascript
Javascript编程中几种继承方式比较分析
2015/11/28 Javascript
Jquery EasyUI实现treegrid上显示checkbox并取选定值的方法
2016/04/29 Javascript
jQuery progressbar通过Ajax请求实现后台进度实时功能
2016/10/11 Javascript
NODE.JS跨域问题的完美解决方案
2016/10/20 Javascript
微信小程序 下拉列表的实现实例代码
2017/03/08 Javascript
基于js 字符串indexof与search方法的区别(详解)
2017/12/04 Javascript
深入理解JS异步编程-Promise
2019/06/03 Javascript
探究Python的Tornado框架对子域名和泛域名的支持
2015/05/02 Python
python数据结构之图深度优先和广度优先实例详解
2015/07/08 Python
python实现音乐下载器
2018/04/15 Python
Python Matplotlib实现三维数据的散点图绘制
2019/03/19 Python
python将字符串转换成json的方法小结
2019/07/09 Python
对Python函数设计规范详解
2019/07/19 Python
python模拟键盘输入 切换键盘布局过程解析
2019/08/15 Python
python+pygame实现坦克大战
2019/09/10 Python
信号生成及DFT的python实现方式
2020/02/25 Python
Windows下Anaconda和PyCharm的安装与使用详解
2020/04/23 Python
Python字符串函数strip()原理及用法详解
2020/07/23 Python
详解修改Anaconda中的Jupyter Notebook默认工作路径的三种方式
2021/01/24 Python
Perricone MD裴礼康美国官网:抗衰老护肤品
2016/09/26 全球购物
Maisons du Monde德国:法国家具和装饰的市场领导者
2019/07/26 全球购物
波兰在线运动商店:YesSport
2020/07/23 全球购物
如何实现jdbc性能优化
2012/07/30 面试题
万年牢教学反思
2014/02/15 职场文书
辅导员评语
2014/05/04 职场文书
乡镇党的群众路线教育实践活动领导班子对照检查材料
2014/09/25 职场文书
小学见习报告
2014/10/31 职场文书
2014年社区卫生工作总结
2014/12/18 职场文书