使用keras实现孪生网络中的权值共享教程


Posted in Python onJune 11, 2020

首先声明,这里的权值共享指的不是CNN原理中的共享权值,而是如何在构建类似于Siamese Network这样的多分支网络,且分支结构相同时,如何使用keras使分支的权重共享。

Functional API

为达到上述的目的,建议使用keras中的Functional API,当然Sequential 类型的模型也可以使用,本篇博客将主要以Functional API为例讲述。

keras的多分支权值共享功能实现,官方文档介绍

上面是官方的链接,本篇博客也是基于上述官方文档,实现的此功能。(插一句,keras虽然有中文文档,但中文文档已停更,且中文文档某些函数介绍不全,建议直接看英文官方文档)

不共享参数的模型

以MatchNet网络结构为例子,为方便显示,将卷积模块个数减为2个。首先是展示不共享参数的模型,以便观看完整的网络结构。

整体的网络结构如下所示:

代码包含两部分,第一部分定义了两个函数,FeatureNetwork()生成特征提取网络,ClassiFilerNet()生成决策网络或称度量网络。网络结构的可视化在博客末尾。在ClassiFilerNet()函数中,可以看到调用了两次FeatureNetwork()函数,keras.models.Model也被使用的两次,因此生成的input1和input2是两个完全独立的模型分支,参数是不共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ---------------------函数功能区-------------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""
  input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
  input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
  for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
    layer.name = layer.name + str("_2")
  inp1 = input1.input
  inp2 = input2.input
  merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

# ---------------------主调区-------------------------
matchnet = ClassiFilerNet()
matchnet.summary() # 打印网络结构
plot_model(matchnet, to_file='G:/csdn攻略/picture/model.png') # 网络结构输出成png图片

共享参数的模型

FeatureNetwork()的功能和上面的功能相同,为方便选择,在ClassiFilerNet()函数中加入了判断是否使用共享参数模型功能,令reuse=True,便使用的是共享参数的模型。

关键地方就在,只使用的一次Model,也就是说只创建了一次模型,虽然输入了两个输入,但其实使用的是同一个模型,因此权重共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ----------------函数功能区-----------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
  # models = Activation('relu')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(reuse=False): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""

  if reuse:
    inp = Input(shape=(28, 28, 1), name='FeatureNet_ImageInput')
    models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
    models = Activation('relu')(models)
    models = MaxPool2D(pool_size=(3, 3))(models)

    models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
    # models = Activation('relu')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Flatten()(models)
    models = Dense(512)(models)
    models = Activation('relu')(models)
    model = Model(inputs=inp, outputs=models)

    inp1 = Input(shape=(28, 28, 1)) # 创建输入
    inp2 = Input(shape=(28, 28, 1)) # 创建输入2
    model_1 = model(inp1) # 孪生网络中的一个特征提取分支
    model_2 = model(inp2) # 孪生网络中的另一个特征提取分支
    merge_layers = concatenate([model_1, model_2]) # 进行融合,使用的是默认的sum,即简单的相加

  else:
    input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
    input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
    for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
      layer.name = layer.name + str("_2")
    inp1 = input1.input
    inp2 = input2.input
    merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

如何看是否真的是权值共享呢?直接对比特征提取部分的网络参数个数!

不共享参数模型的参数数量:

使用keras实现孪生网络中的权值共享教程

共享参数模型的参数总量

使用keras实现孪生网络中的权值共享教程

共享参数模型中的特征提取部分的参数量为:

使用keras实现孪生网络中的权值共享教程

由于截图限制,不共享参数模型的特征提取网络参数数量不再展示。其实经过计算,特征提取网络部分的参数数量,不共享参数模型是共享参数的两倍。两个网络总参数量的差值就是,共享模型中,特征提取部分的参数的量

网络结构可视化

不共享权重的网络结构

使用keras实现孪生网络中的权值共享教程

共享参数的网络结构,其中model_1代表的就是特征提取部分。

使用keras实现孪生网络中的权值共享教程

以上这篇使用keras实现孪生网络中的权值共享教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python获得时间的实例说明
Mar 25 Python
Python脚本实现集群检测和管理功能
Mar 06 Python
Python实现给qq邮箱发送邮件的方法
May 28 Python
举例讲解Python中的身份运算符的使用方法
Oct 13 Python
Python中对象迭代与反迭代的技巧总结
Sep 17 Python
python 实现返回一个列表中出现次数最多的元素方法
Jun 11 Python
python能做什么 python的含义
Oct 12 Python
python 实现将小图片放到另一个较大的白色或黑色背景图片中
Dec 12 Python
在Pytorch中计算卷积方法的区别详解(conv2d的区别)
Jan 03 Python
浅析matlab中imadjust函数
Feb 27 Python
Python中X[:,0]和X[:,1]的用法
May 10 Python
Python爬虫基础讲解之请求
May 13 Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
keras load model时出现Missing Layer错误的解决方式
Jun 11 #Python
Pyinstaller加密打包应用的示例代码
Jun 11 #Python
解决keras加入lambda层时shape的问题
Jun 11 #Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 #Python
You might like
英雄试炼之肉山谷—引领RPG新潮流
2020/04/20 DOTA
Apache 配置详解(最好的APACHE配置教程)
2010/07/04 PHP
php获取淘宝分类id示例
2014/01/16 PHP
Js 获取当前日期时间及其它操作实现代码
2021/03/04 Javascript
JavaScript 权威指南(第四版) 读书笔记
2009/08/11 Javascript
javascript原生和jquery库实现iframe自适应高度和宽度
2014/07/18 Javascript
js使用removeChild方法动态删除div元素
2014/08/01 Javascript
cocos2dx骨骼动画Armature源码剖析(一)
2015/09/08 Javascript
jQuery Easyui 验证两次密码输入是否相等
2016/05/13 Javascript
jQuery获取select选中的option的value值实现方法
2016/08/29 Javascript
JavaScript中的ajax功能的概念和示例详解
2016/10/17 Javascript
JS动态添加选项案例分析
2016/10/17 Javascript
jquery ui sortable拖拽后保存位置
2017/04/27 jQuery
Js中async/await的执行顺序详解
2017/09/22 Javascript
详解webpack-dev-server 设置反向代理解决跨域问题
2018/04/18 Javascript
JS使用Date对象实时显示当前系统时间简单示例
2018/08/23 Javascript
vue实现点击选中,其他的不选中方法
2018/09/05 Javascript
基于vue 实现表单中password输入的显示与隐藏功能
2019/07/19 Javascript
超详细的5个Shell脚本实例分享(值得收藏)
2019/08/15 Javascript
[01:09]DOTA2次级职业联赛 - 99战队宣传片
2014/12/01 DOTA
python 多线程应用介绍
2012/12/19 Python
Python实现加载及解析properties配置文件的方法
2018/03/29 Python
PyCharm设置SSH远程调试的方法
2018/07/17 Python
Python何时应该使用Lambda函数
2019/07/02 Python
Python绘图之二维图与三维图详解
2020/08/04 Python
一篇文章带你搞定Ubuntu中打开Pycharm总是卡顿崩溃
2020/11/02 Python
Python实现随机爬山算法
2021/01/29 Python
玩转CSS3色彩
2010/01/16 HTML / CSS
利用纯html5绘制出来的一款非常漂亮的时钟
2015/01/04 HTML / CSS
兰兰过桥教学反思
2014/02/08 职场文书
机电系毕业生求职信
2014/07/11 职场文书
财政局党的群众路线教育实践活动整改方案
2014/09/21 职场文书
2015年社区服务活动总结
2015/03/25 职场文书
安全事故隐患排查治理制度
2015/08/05 职场文书
反邪教教育心得体会
2016/01/15 职场文书
同学会演讲稿
2019/04/02 职场文书