Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python捕捉和模拟鼠标事件的方法
Jun 03 Python
在Django的form中使用CSS进行设计的方法
Jul 18 Python
tensorflow saver 保存和恢复指定 tensor的实例讲解
Jul 26 Python
python按照多个条件排序的方法
Feb 08 Python
Python实现读取txt文件中的数据并绘制出图形操作示例
Feb 26 Python
python时间序列按频率生成日期的方法
May 14 Python
解决pycharm下os.system执行命令返回有中文乱码的问题
Jul 07 Python
对Python函数设计规范详解
Jul 19 Python
python多进程并行代码实例
Sep 30 Python
Python创建一个元素都为0的列表实例
Nov 28 Python
python通过文本在一个图中画多条线的实例
Feb 21 Python
Django使用echarts进行可视化展示的实践
Jun 10 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
如何利用php+mysql保存和输出文件
2006/10/09 PHP
fleaphp rolesNameField bug解决方法
2011/04/23 PHP
php实现图片缩略图的方法
2016/03/29 PHP
PHP微信红包生成代码分享
2016/10/06 PHP
5 cool javascript apps
2007/03/24 Javascript
只需20行代码就可以写出CSS覆盖率测试脚本
2013/04/24 Javascript
JavaScript中获取HTML元素值的三种方法
2016/06/20 Javascript
bootstrap输入框组件使用方法详解
2017/01/19 Javascript
JavaScript如何一次性展示几万条数据
2017/03/30 Javascript
js防刷新的倒计时代码 js倒计时代码
2017/09/06 Javascript
vue+vue-validator 表单验证功能的实现代码
2017/11/13 Javascript
Vue仿支付宝支付功能
2018/05/25 Javascript
vue实现简单loading进度条
2018/06/06 Javascript
JS数据类型STRING使用实例解析
2019/12/18 Javascript
JS实现手写 forEach算法示例
2020/04/29 Javascript
浅谈React中组件逻辑复用的那些事儿
2020/05/21 Javascript
浅谈vue生命周期共有几个阶段?分别是什么?
2020/08/07 Javascript
vue 调用 RESTful风格接口操作
2020/08/11 Javascript
python实现查询苹果手机维修进度
2015/03/16 Python
python opencv设置摄像头分辨率以及各个参数的方法
2018/04/02 Python
django 修改server端口号的方法
2018/05/14 Python
python中使用iterrows()对dataframe进行遍历的实例
2018/06/09 Python
详解Python字符串切片
2019/05/20 Python
浅谈pytorch、cuda、python的版本对齐问题
2020/01/15 Python
python读取yaml文件后修改写入本地实例
2020/04/27 Python
台湾前三大B2C购物网站:MOMO购物网
2017/04/27 全球购物
蒙蒂塞罗商店:Monticello Shop
2018/11/25 全球购物
马来西亚奢侈品牌购物商城:Valiram 247
2020/09/29 全球购物
结构和类有什么异同
2012/07/16 面试题
宠物店的创业计划书范文
2014/01/11 职场文书
中秋节主持词
2014/04/02 职场文书
2014年惩防体系建设工作总结
2014/12/01 职场文书
幼儿园校车安全责任书
2015/05/08 职场文书
使用python如何删除同一文件夹下相似的图片
2021/05/07 Python
MySQL索引失效的典型案例
2021/06/05 MySQL
剑指Offer之Java算法习题精讲二叉树专项训练
2022/03/21 Java/Android