Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python用zip函数同时遍历多个迭代器示例详解
Nov 14 Python
python 数据清洗之数据合并、转换、过滤、排序
Feb 12 Python
Python变量赋值的秘密分享
Apr 03 Python
python pygame模块编写飞机大战
Nov 20 Python
在python中获取div的文本内容并和想定结果进行对比详解
Jan 02 Python
Python基于datetime或time模块分别获取当前时间戳的方法实例
Feb 19 Python
python实现简单聊天室功能 可以私聊
Jul 12 Python
Django如何使用第三方服务发送电子邮件
Aug 14 Python
用python求一重积分和二重积分的例子
Dec 06 Python
python selenium实现发送带附件的邮件代码实例
Dec 10 Python
pytorch 数据处理:定义自己的数据集合实例
Dec 31 Python
音频处理 windows10下python三方库librosa安装教程
Jun 20 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
php项目打包方法
2008/02/18 PHP
双击滚屏-常用推荐
2006/11/29 Javascript
ImageFlow可鼠标控制图片滚动
2008/01/30 Javascript
基于jQuery试卷自动排版系统
2010/07/18 Javascript
javascript 事件处理程序介绍
2012/06/27 Javascript
JSON+JavaScript处理JSON的简单例子
2013/03/20 Javascript
JavaScript面向对象编程入门教程
2014/04/16 Javascript
JS实现div居中示例
2014/04/17 Javascript
ECMAScript6的新特性箭头函数(Arrow Function)详细介绍
2014/06/07 Javascript
Vue.js每天必学之内部响应式原理探究
2016/09/07 Javascript
js 转json格式的字符串为对象或数组(前后台)的方法
2016/11/02 Javascript
AngularJS使用ng-repeat和ng-if实现数据的删选显示效果示例【适用于表单数据的显示】
2016/12/13 Javascript
vue实现ajax滚动下拉加载,同时具有loading效果(推荐)
2017/01/11 Javascript
JS实现自定义状态栏动画文字效果示例
2017/10/12 Javascript
JS在Array数组中按指定位置删除或添加元素对象方法示例
2019/11/19 Javascript
微信小程序实现搜索框功能及踩过的坑
2020/06/19 Javascript
基于vue+element实现全局loading过程详解
2020/07/10 Javascript
[00:36]DOTA2风云人物相约完美“圣”典 12月17日不见不散
2016/11/30 DOTA
在Django的视图中使用数据库查询的方法
2015/07/16 Python
Python 详解基本语法_函数_返回值
2017/01/22 Python
python模拟事件触发机制详解
2018/01/19 Python
python pandas库的安装和创建
2019/01/10 Python
pytorch 模型可视化的例子
2019/08/17 Python
Python求平面内点到直线距离的实现
2020/01/19 Python
使用Keras预训练好的模型进行目标类别预测详解
2020/06/27 Python
使用Python封装excel操作指南
2021/01/29 Python
美国二手奢侈品寄售网站:TheRealReal
2016/10/29 全球购物
函授大专自我鉴定
2013/11/01 职场文书
春节联欢会主持词
2014/03/24 职场文书
法院先进个人事迹材料
2014/05/04 职场文书
公司委托书怎么写
2014/08/02 职场文书
学校2014重阳节活动策划方案
2014/09/16 职场文书
党的群众路线教育实践活动个人对照检查材料(公安)
2014/11/05 职场文书
25句企业管理语录:助你迅速打开思路,句句经典!
2020/01/14 职场文书
详解PHP服务器如何在有限的资源里最大提升并发能力
2021/05/25 PHP
PHP中多字节字符串操作实例详解
2021/08/23 PHP