使用keras实现孪生网络中的权值共享教程


Posted in Python onJune 11, 2020

首先声明,这里的权值共享指的不是CNN原理中的共享权值,而是如何在构建类似于Siamese Network这样的多分支网络,且分支结构相同时,如何使用keras使分支的权重共享。

Functional API

为达到上述的目的,建议使用keras中的Functional API,当然Sequential 类型的模型也可以使用,本篇博客将主要以Functional API为例讲述。

keras的多分支权值共享功能实现,官方文档介绍

上面是官方的链接,本篇博客也是基于上述官方文档,实现的此功能。(插一句,keras虽然有中文文档,但中文文档已停更,且中文文档某些函数介绍不全,建议直接看英文官方文档)

不共享参数的模型

以MatchNet网络结构为例子,为方便显示,将卷积模块个数减为2个。首先是展示不共享参数的模型,以便观看完整的网络结构。

整体的网络结构如下所示:

代码包含两部分,第一部分定义了两个函数,FeatureNetwork()生成特征提取网络,ClassiFilerNet()生成决策网络或称度量网络。网络结构的可视化在博客末尾。在ClassiFilerNet()函数中,可以看到调用了两次FeatureNetwork()函数,keras.models.Model也被使用的两次,因此生成的input1和input2是两个完全独立的模型分支,参数是不共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ---------------------函数功能区-------------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""
  input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
  input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
  for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
    layer.name = layer.name + str("_2")
  inp1 = input1.input
  inp2 = input2.input
  merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

# ---------------------主调区-------------------------
matchnet = ClassiFilerNet()
matchnet.summary() # 打印网络结构
plot_model(matchnet, to_file='G:/csdn攻略/picture/model.png') # 网络结构输出成png图片

共享参数的模型

FeatureNetwork()的功能和上面的功能相同,为方便选择,在ClassiFilerNet()函数中加入了判断是否使用共享参数模型功能,令reuse=True,便使用的是共享参数的模型。

关键地方就在,只使用的一次Model,也就是说只创建了一次模型,虽然输入了两个输入,但其实使用的是同一个模型,因此权重共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ----------------函数功能区-----------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
  # models = Activation('relu')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(reuse=False): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""

  if reuse:
    inp = Input(shape=(28, 28, 1), name='FeatureNet_ImageInput')
    models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
    models = Activation('relu')(models)
    models = MaxPool2D(pool_size=(3, 3))(models)

    models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
    # models = Activation('relu')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Flatten()(models)
    models = Dense(512)(models)
    models = Activation('relu')(models)
    model = Model(inputs=inp, outputs=models)

    inp1 = Input(shape=(28, 28, 1)) # 创建输入
    inp2 = Input(shape=(28, 28, 1)) # 创建输入2
    model_1 = model(inp1) # 孪生网络中的一个特征提取分支
    model_2 = model(inp2) # 孪生网络中的另一个特征提取分支
    merge_layers = concatenate([model_1, model_2]) # 进行融合,使用的是默认的sum,即简单的相加

  else:
    input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
    input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
    for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
      layer.name = layer.name + str("_2")
    inp1 = input1.input
    inp2 = input2.input
    merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

如何看是否真的是权值共享呢?直接对比特征提取部分的网络参数个数!

不共享参数模型的参数数量:

使用keras实现孪生网络中的权值共享教程

共享参数模型的参数总量

使用keras实现孪生网络中的权值共享教程

共享参数模型中的特征提取部分的参数量为:

使用keras实现孪生网络中的权值共享教程

由于截图限制,不共享参数模型的特征提取网络参数数量不再展示。其实经过计算,特征提取网络部分的参数数量,不共享参数模型是共享参数的两倍。两个网络总参数量的差值就是,共享模型中,特征提取部分的参数的量

网络结构可视化

不共享权重的网络结构

使用keras实现孪生网络中的权值共享教程

共享参数的网络结构,其中model_1代表的就是特征提取部分。

使用keras实现孪生网络中的权值共享教程

以上这篇使用keras实现孪生网络中的权值共享教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现股市信息下载的方法
Jun 15 Python
python实现的希尔排序算法实例
Jul 01 Python
5种Python单例模式的实现方式
Jan 14 Python
Python数据操作方法封装类实例
Jun 23 Python
python3获取当前文件的上一级目录实例
Apr 26 Python
pip命令无法使用的解决方法
Jun 12 Python
django orm 通过related_name反向查询的方法
Dec 15 Python
解决python flask中config配置管理的问题
Jul 26 Python
利用Pytorch实现简单的线性回归算法
Jan 15 Python
Python使用graphviz画流程图过程解析
Mar 31 Python
Python的PIL库中getpixel方法的使用
Apr 09 Python
python 基于pygame实现俄罗斯方块
Mar 02 Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
keras load model时出现Missing Layer错误的解决方式
Jun 11 #Python
Pyinstaller加密打包应用的示例代码
Jun 11 #Python
解决keras加入lambda层时shape的问题
Jun 11 #Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 #Python
You might like
盘点被央视点名过的日本动画电影 一部比一部强
2020/03/08 日漫
用Apache反向代理设置对外的WWW和文件服务器
2006/10/09 PHP
javascript some()函数用法详解
2014/11/13 PHP
php使用include 和require引入文件的区别
2017/02/16 PHP
php PDO实现的事务回滚示例
2017/03/23 PHP
PHP基于phpqrcode类生成二维码的方法详解
2018/03/14 PHP
PHP实现的mysql读写分离操作示例
2018/05/22 PHP
php微信公众号开发之快递查询
2018/10/20 PHP
PHP文件后缀不强制为.php方法
2019/03/31 PHP
写入cookie的JavaScript代码库 cookieLibrary.js
2009/10/24 Javascript
日历查询的算法 如何计算某一天是星期几
2012/12/12 Javascript
JavaScript实现将xml转换成html table表格的方法
2015/04/17 Javascript
javascript中 try catch用法
2015/08/16 Javascript
JS实现自动定时切换的简洁网页选项卡效果
2015/10/13 Javascript
基于jquery实现轮播特效
2016/04/22 Javascript
一句jQuery代码实现返回顶部效果(简单实用)
2016/12/28 Javascript
angular.js + require.js构建模块化单页面应用的方法步骤
2017/07/19 Javascript
简单实现vue验证码60秒倒计时功能
2017/10/11 Javascript
JavaScript数组,JSON对象实现动态添加、修改、删除功能示例
2018/05/26 Javascript
JS正则表达式常见用法实例详解
2018/06/19 Javascript
angular6根据environments配置文件更改开发所需要的环境的方法
2019/03/06 Javascript
JavaScript检测浏览器是否支持CSS变量代码实例
2020/04/03 Javascript
Python算法之栈(stack)的实现
2014/08/18 Python
python简单获取数组元素个数的方法
2015/07/13 Python
详解Django+Uwsgi+Nginx的生产环境部署
2018/06/25 Python
Python3 itchat实现微信定时发送群消息的实例代码
2019/07/12 Python
python扫描线填充算法详解
2020/02/19 Python
python如何调用java类
2020/07/05 Python
全网最全python库selenium自动化使用详细教程
2021/01/12 Python
matplotlib交互式数据光标实现(mplcursors)
2021/01/13 Python
法学毕业生自我鉴定
2013/11/08 职场文书
医院实习介绍信
2014/01/12 职场文书
护士求职信范文
2014/05/24 职场文书
镇创先争优活动总结
2014/08/28 职场文书
酒店工程部主管岗位职责
2015/04/16 职场文书
预备党员入党感言
2015/08/01 职场文书