Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python线程池的实现实例
Nov 18 Python
使用python装饰器验证配置文件示例
Feb 24 Python
详解Django中的权限和组以及消息
Jul 23 Python
python 全局变量的import机制介绍
Sep 07 Python
python pandas.DataFrame选取、修改数据最好用.loc,.iloc,.ix实现
Jun 11 Python
Python使用logging模块实现打印log到指定文件的方法
Sep 05 Python
python re库的正则表达式入门学习教程
Mar 08 Python
Python 多线程C段扫描、检测 Ping扫描脚本的实现
Sep 03 Python
安装并免费使用Pycharm专业版(学生/教师)
Sep 24 Python
python两种获取剪贴板内容的方法
Nov 06 Python
python 将Excel转Word的示例
Mar 02 Python
分享Python获取本机IP地址的几种方法
Mar 17 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
短波问题解答
2021/02/28 无线电
学习php设计模式 php实现工厂模式(factory)
2015/12/07 PHP
php简单计算年龄的方法(周岁与虚岁)
2016/12/06 PHP
Thinkphp 框架基础之源码获取、环境要求与目录结构分析
2020/04/27 PHP
Avengerls vs KG BO3 第二场2.18
2021/03/10 DOTA
js 函数的执行环境和作用域链的深入解析
2009/11/01 Javascript
js中判断数字\字母\中文的正则表达式 (实例)
2012/06/29 Javascript
jquery 实现checkbox全选,反选,全不选等功能代码(奇数)
2012/10/24 Javascript
jQuery $.extend()用法总结
2014/06/15 Javascript
谷歌showModalDialog()方法不兼容出现对话窗口的解决办法
2016/02/15 Javascript
基于JQuery实现分隔条的功能
2016/06/17 Javascript
js初始化验证实例详解
2016/11/26 Javascript
nodejs 实现钉钉ISV接入的加密解密方法
2017/01/16 NodeJs
nodejs模块nodemailer基本使用-邮件发送示例(支持附件)
2017/03/28 NodeJs
Angular中自定义Debounce Click指令防止重复点击
2017/07/26 Javascript
Vue引用Swiper4插件无法重写分页器样式的解决方法
2018/09/27 Javascript
JavaScript递归函数定义与用法实例分析
2019/01/24 Javascript
如何用JavaScript实现功能齐全的单链表详解
2019/02/11 Javascript
在Python中关于中文编码问题的处理建议
2015/04/08 Python
一个基于flask的web应用诞生 记录用户账户登录状态(6)
2017/04/11 Python
让Django支持Sql Server作后端数据库的方法
2018/05/29 Python
python实现txt文件格式转换为arff格式
2018/05/31 Python
使用Python实现从各个子文件夹中复制指定文件的方法
2018/10/25 Python
使用Python操作FTP实现上传和下载的方法
2019/04/01 Python
python多进程并行代码实例
2019/09/30 Python
记录模型训练时loss值的变化情况
2020/06/16 Python
利用python制作拼图小游戏的全过程
2020/12/04 Python
python中random模块详解
2021/03/01 Python
琳达·法罗眼镜英国官网:Linda Farrow英国
2021/01/19 全球购物
会计应聘求职信范文
2013/12/17 职场文书
教师一帮一活动总结
2014/07/08 职场文书
加强作风建设心得体会
2014/10/22 职场文书
场地使用证明模板
2014/10/25 职场文书
2016婚礼主持词开场白
2015/11/24 职场文书
在CSS中使用when/else的方法
2022/01/18 HTML / CSS
java开发双人五子棋游戏
2022/05/06 Java/Android