使用keras实现孪生网络中的权值共享教程


Posted in Python onJune 11, 2020

首先声明,这里的权值共享指的不是CNN原理中的共享权值,而是如何在构建类似于Siamese Network这样的多分支网络,且分支结构相同时,如何使用keras使分支的权重共享。

Functional API

为达到上述的目的,建议使用keras中的Functional API,当然Sequential 类型的模型也可以使用,本篇博客将主要以Functional API为例讲述。

keras的多分支权值共享功能实现,官方文档介绍

上面是官方的链接,本篇博客也是基于上述官方文档,实现的此功能。(插一句,keras虽然有中文文档,但中文文档已停更,且中文文档某些函数介绍不全,建议直接看英文官方文档)

不共享参数的模型

以MatchNet网络结构为例子,为方便显示,将卷积模块个数减为2个。首先是展示不共享参数的模型,以便观看完整的网络结构。

整体的网络结构如下所示:

代码包含两部分,第一部分定义了两个函数,FeatureNetwork()生成特征提取网络,ClassiFilerNet()生成决策网络或称度量网络。网络结构的可视化在博客末尾。在ClassiFilerNet()函数中,可以看到调用了两次FeatureNetwork()函数,keras.models.Model也被使用的两次,因此生成的input1和input2是两个完全独立的模型分支,参数是不共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ---------------------函数功能区-------------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""
  input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
  input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
  for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
    layer.name = layer.name + str("_2")
  inp1 = input1.input
  inp2 = input2.input
  merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

# ---------------------主调区-------------------------
matchnet = ClassiFilerNet()
matchnet.summary() # 打印网络结构
plot_model(matchnet, to_file='G:/csdn攻略/picture/model.png') # 网络结构输出成png图片

共享参数的模型

FeatureNetwork()的功能和上面的功能相同,为方便选择,在ClassiFilerNet()函数中加入了判断是否使用共享参数模型功能,令reuse=True,便使用的是共享参数的模型。

关键地方就在,只使用的一次Model,也就是说只创建了一次模型,虽然输入了两个输入,但其实使用的是同一个模型,因此权重共享的。

from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model

# ----------------函数功能区-----------------------
def FeatureNetwork():
  """生成特征提取网络"""
  """这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
  inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
  models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
  models = Activation('relu')(models)
  models = MaxPool2D(pool_size=(3, 3))(models)

  models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
  models = Activation('relu')(models)

  # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
  # models = Activation('relu')(models)
  # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
  models = Flatten()(models)
  models = Dense(512)(models)
  models = Activation('relu')(models)
  model = Model(inputs=inp, outputs=models)
  return model

def ClassiFilerNet(reuse=False): # add classifier Net
  """生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""

  if reuse:
    inp = Input(shape=(28, 28, 1), name='FeatureNet_ImageInput')
    models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
    models = Activation('relu')(models)
    models = MaxPool2D(pool_size=(3, 3))(models)

    models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
    models = Activation('relu')(models)

    # models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
    # models = Activation('relu')(models)
    # models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
    models = Flatten()(models)
    models = Dense(512)(models)
    models = Activation('relu')(models)
    model = Model(inputs=inp, outputs=models)

    inp1 = Input(shape=(28, 28, 1)) # 创建输入
    inp2 = Input(shape=(28, 28, 1)) # 创建输入2
    model_1 = model(inp1) # 孪生网络中的一个特征提取分支
    model_2 = model(inp2) # 孪生网络中的另一个特征提取分支
    merge_layers = concatenate([model_1, model_2]) # 进行融合,使用的是默认的sum,即简单的相加

  else:
    input1 = FeatureNetwork()           # 孪生网络中的一个特征提取
    input2 = FeatureNetwork()           # 孪生网络中的另一个特征提取
    for layer in input2.layers:          # 这个for循环一定要加,否则网络重名会出错。
      layer.name = layer.name + str("_2")
    inp1 = input1.input
    inp2 = input2.input
    merge_layers = concatenate([input1.output, input2.output])    # 进行融合,使用的是默认的sum,即简单的相加
  fc1 = Dense(1024, activation='relu')(merge_layers)
  fc2 = Dense(1024, activation='relu')(fc1)
  fc3 = Dense(2, activation='softmax')(fc2)

  class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
  return class_models

如何看是否真的是权值共享呢?直接对比特征提取部分的网络参数个数!

不共享参数模型的参数数量:

使用keras实现孪生网络中的权值共享教程

共享参数模型的参数总量

使用keras实现孪生网络中的权值共享教程

共享参数模型中的特征提取部分的参数量为:

使用keras实现孪生网络中的权值共享教程

由于截图限制,不共享参数模型的特征提取网络参数数量不再展示。其实经过计算,特征提取网络部分的参数数量,不共享参数模型是共享参数的两倍。两个网络总参数量的差值就是,共享模型中,特征提取部分的参数的量

网络结构可视化

不共享权重的网络结构

使用keras实现孪生网络中的权值共享教程

共享参数的网络结构,其中model_1代表的就是特征提取部分。

使用keras实现孪生网络中的权值共享教程

以上这篇使用keras实现孪生网络中的权值共享教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在ironpython中利用装饰器执行SQL操作的例子
May 02 Python
python实现壁纸批量下载代码实例
Jan 25 Python
python设计tcp数据包协议类的例子
Jul 23 Python
django中使用Celery 布式任务队列过程详解
Jul 29 Python
python求最大公约数和最小公倍数的简单方法
Feb 13 Python
Python的Django框架实现数据库查询(不返回QuerySet的方法)
May 19 Python
基于Keras中Conv1D和Conv2D的区别说明
Jun 19 Python
Pytho爬虫中Requests设置请求头Headers的方法
Sep 22 Python
如何基于pandas读取csv后合并两个股票
Sep 25 Python
matplotlib自定义鼠标光标坐标格式的实现
Jan 08 Python
教你用Python matplotlib库制作简单的动画
Jun 11 Python
python分分钟绘制精美地图海报
Feb 15 Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
keras load model时出现Missing Layer错误的解决方式
Jun 11 #Python
Pyinstaller加密打包应用的示例代码
Jun 11 #Python
解决keras加入lambda层时shape的问题
Jun 11 #Python
python opencv把一张图片嵌入(叠加)到另一张图片上的实现代码
Jun 11 #Python
You might like
openflashchart 2.0 简单案例php版
2012/05/21 PHP
大家在抢红包,程序员在研究红包算法
2015/08/31 PHP
PHP使用PHPexcel导入导出数据的方法
2015/11/14 PHP
php使用curl并发减少后端访问时间的方法分析
2016/05/12 PHP
PHP中用mysqli面向对象打开连接关闭mysql数据库的方法
2016/11/05 PHP
快速解决PHP调用Word组件DCOM权限的问题
2017/12/27 PHP
PHP使用Redis长连接的方法详解
2018/02/12 PHP
利用onresize使得div可以随着屏幕大小而自适应的代码
2010/01/15 Javascript
chrome浏览器不支持onmouseleave事件的解决技巧
2013/05/31 Javascript
jQuery操作元素css样式的三种方法
2014/06/04 Javascript
JavaScript函数详解
2014/11/17 Javascript
轻松创建nodejs服务器(10):处理上传图片
2014/12/18 NodeJs
使用Plupload实现直接上传附件至七牛云存储
2014/12/26 Javascript
jquery动态改变div宽度和高度
2015/02/09 Javascript
AngularJS基础学习笔记之控制器
2015/05/10 Javascript
JavaScript检查子字符串是否在字符串中的方法
2016/02/03 Javascript
AngularJS通过$location获取及改变当前页面的URL
2016/09/23 Javascript
vue组件父与子通信详解(一)
2017/11/07 Javascript
通过layer实现可输入的模态框的例子
2019/09/27 Javascript
[01:34]DAC2018主赛事第四日五佳镜头 Gh巨牙海民助Miracle-死里逃生
2018/04/07 DOTA
python文件写入实例分析
2015/04/08 Python
基于Python3 逗号代码 和 字符图网格(详谈)
2017/06/22 Python
Python 画出来六维图
2019/07/26 Python
学python安装的软件总结
2019/10/12 Python
Canvas 帧动画吃苹果小游戏
2020/08/05 HTML / CSS
欧洲最大的品牌水上运动服装和设备在线零售商:Wuituit Outlet
2018/05/05 全球购物
法国足球商店:Footcenter
2019/07/06 全球购物
莫斯科制造商的廉价皮大衣:Fursk
2020/06/09 全球购物
临床医学专业毕业生的自我评价
2013/10/17 职场文书
国际贸易专业个人鉴定
2014/02/22 职场文书
《蜗牛的奖杯》教后反思
2014/04/24 职场文书
企业员工辞职信范文
2015/05/12 职场文书
职工培训工作总结
2015/08/10 职场文书
Css预编语言及区别详解
2021/04/25 HTML / CSS
用python修改excel表某一列内容的操作方法
2021/06/11 Python
Java中生成微信小程序太阳码的实现方案
2022/06/01 Java/Android