使用keras实现孪生网络中的权值共享教程


Posted in Python onJune 11, 2020

首先声明,这里的权值共享指的不是CNN原理中的共享权值,而是如何在构建类似于Siamese Network这样的多分支网络,且分支结构相同时,如何使用keras使分支的权重共享。

Functional API

为达到上述的目的,建议使用keras中的Functional API,当然Sequential 类型的模型也可以使用,本篇博客将主要以Functional API为例讲述。

keras的多分支权值共享功能实现,官方文档介绍

上面是官方的链接,本篇博客也是基于上述官方文档,实现的此功能。(插一句,keras虽然有中文文档,但中文文档已停更,且中文文档某些函数介绍不全,建议直接看英文官方文档)

不共享参数的模型

以MatchNet网络结构为例子,为方便显示,将卷积模块个数减为2个。首先是展示不共享参数的模型,以便观看完整的网络结构。

整体的网络结构如下所示:

代码包含两部分,第一部分定义了两个函数,FeatureNetwork()生成特征提取网络,ClassiFilerNet()生成决策网络或称度量网络。网络结构的可视化在博客末尾。在ClassiFilerNet()函数中,可以看到调用了两次FeatureNetwork()函数,keras.models.Model也被使用的两次,因此生成的input1和input2是两个完全独立的模型分支,参数是不共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ---------------------函数功能区-------------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""
  input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
  input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
  for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
    layer.name = layer.name + str("_2")
  inp1 = input1.input
  inp2 = input2.input
  merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

# ---------------------主调区-------------------------
matchnet = ClassiFilerNet()
matchnet.summary() # 打印网络结构
plot_model(matchnet, to_file='G:/csdn攻略/picture/model.png') # 网络结构输出成png图片

共享参数的模型

FeatureNetwork()的功能和上面的功能相同,为方便选择,在ClassiFilerNet()函数中加入了判断是否使用共享参数模型功能,令reuse=True,便使用的是共享参数的模型。

关键地方就在,只使用的一次Model,也就是说只创建了一次模型,虽然输入了两个输入,但其实使用的是同一个模型,因此权重共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ----------------函数功能区-----------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
  # models = Activation('relu')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(reuse=False): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""

  if reuse:
    inp = Input(shape=(28, 28, 1), name='FeatureNet_ImageInput')
    models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
    models = Activation('relu')(models)
    models = MaxPool2D(pool_size=(3, 3))(models)

    models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
    # models = Activation('relu')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Flatten()(models)
    models = Dense(512)(models)
    models = Activation('relu')(models)
    model = Model(inputs=inp, outputs=models)

    inp1 = Input(shape=(28, 28, 1)) # 创建输入
    inp2 = Input(shape=(28, 28, 1)) # 创建输入2
    model_1 = model(inp1) # 孪生网络中的一个特征提取分支
    model_2 = model(inp2) # 孪生网络中的另一个特征提取分支
    merge_layers = concatenate([model_1, model_2]) # 进行融合,使用的是默认的sum,即简单的相加

  else:
    input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
    input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
    for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
      layer.name = layer.name + str("_2")
    inp1 = input1.input
    inp2 = input2.input
    merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

如何看是否真的是权值共享呢?直接对比特征提取部分的网络参数个数!

不共享参数模型的参数数量:

使用keras实现孪生网络中的权值共享教程

共享参数模型的参数总量

使用keras实现孪生网络中的权值共享教程

共享参数模型中的特征提取部分的参数量为:

使用keras实现孪生网络中的权值共享教程

由于截图限制,不共享参数模型的特征提取网络参数数量不再展示。其实经过计算,特征提取网络部分的参数数量,不共享参数模型是共享参数的两倍。两个网络总参数量的差值就是,共享模型中,特征提取部分的参数的量

网络结构可视化

不共享权重的网络结构

使用keras实现孪生网络中的权值共享教程

共享参数的网络结构,其中model_1代表的就是特征提取部分。

使用keras实现孪生网络中的权值共享教程

以上这篇使用keras实现孪生网络中的权值共享教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
30分钟搭建Python的Flask框架并在上面编写第一个应用
Mar 30 Python
老生常谈python之鸭子类和多态
Jun 13 Python
将python代码和注释分离的方法
Apr 21 Python
pandas如何处理缺失值
Jul 31 Python
tensor和numpy的互相转换的实现示例
Aug 02 Python
Centos7 下安装最新的python3.8
Oct 28 Python
python定义类self用法实例解析
Jan 22 Python
将python文件打包exe独立运行程序方法详解
Feb 12 Python
python matplotlib:plt.scatter() 大小和颜色参数详解
Apr 14 Python
关于Keras Dense层整理
May 21 Python
Python QT组件库qtwidgets的使用
Nov 02 Python
Python的Tqdm模块实现进度条配置
Feb 24 Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
keras load model时出现Missing Layer错误的解决方式
Jun 11 #Python
Pyinstaller加密打包应用的示例代码
Jun 11 #Python
解决keras加入lambda层时shape的问题
Jun 11 #Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 #Python
You might like
彻底杜绝PHP的session cookie错误
2009/08/09 PHP
rrmdir php中递归删除目录及目录下的文件
2011/05/15 PHP
PHP基于imap获取邮件实例
2014/11/11 PHP
php插件Xajax使用方法详解
2017/08/31 PHP
PHP-X系列教程之内置函数的使用示例
2017/10/16 PHP
JavaScript 浏览器验证代码(来自discuz)
2010/07/17 Javascript
JQUERY 获取IFrame中对象及获取其父窗口中对象示例
2013/08/19 Javascript
JS生成不重复随机数组的函数代码
2014/06/10 Javascript
javascript实现字符串反转的方法
2015/02/05 Javascript
JavaScript的设计模式经典之代理模式
2016/02/24 Javascript
node.js实现博客小爬虫的实例代码
2016/10/08 Javascript
JS为什么说async/await是generator的语法糖详解
2019/07/11 Javascript
使用python生成杨辉三角形的示例代码
2018/08/29 Python
基于python实现名片管理系统
2018/11/30 Python
python实现石头剪刀布小游戏
2021/01/20 Python
基于python的BP神经网络及异或实现过程解析
2019/09/30 Python
Python调用scp向服务器上传文件示例
2019/12/22 Python
django执行原始查询sql,并返回Dict字典例子
2020/04/01 Python
localStorage、sessionStorage使用总结
2017/11/17 HTML / CSS
UGG美国官网:购买UGG雪地靴、拖鞋和鞋子
2017/12/31 全球购物
装修致歉信
2014/01/15 职场文书
开门红主持词
2014/04/02 职场文书
兴趣小组活动总结
2014/05/05 职场文书
公司会议策划方案
2014/05/17 职场文书
国际商贸专业自荐信
2014/06/09 职场文书
党的群众教育实践活动实施方案
2014/06/12 职场文书
门卫岗位职责说明书
2014/08/18 职场文书
法人代表证明书格式
2014/10/01 职场文书
个人工作总结范文2014
2014/11/07 职场文书
七年级上册语文教学计划
2015/01/22 职场文书
爱晚亭导游词
2015/02/09 职场文书
工作岗位职责范本
2015/02/15 职场文书
农村环境卫生倡议书
2015/04/29 职场文书
运动会致辞稿
2015/07/29 职场文书
SQLServer RANK() 排名函数的使用
2022/03/23 SQL Server
PostgreSQL出现死锁该如何解决
2022/05/30 PostgreSQL