Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中pygame针对游戏窗口的显示方法实例分析(附源码)
Nov 11 Python
python套接字流重定向实例汇总
Mar 03 Python
python解决网站的反爬虫策略总结
Oct 26 Python
http请求 request失败自动重新尝试代码示例
Jan 25 Python
python使用Tesseract库识别验证
Mar 21 Python
对python中的logger模块全面讲解
Apr 28 Python
用python写扫雷游戏实例代码分享
May 27 Python
python 反向输出字符串的方法
Jul 16 Python
Python实现的tcp端口检测操作示例
Jul 24 Python
在PyTorch中使用标签平滑正则化的问题
Apr 03 Python
基于Keras 循环训练模型跑数据时内存泄漏的解决方式
Jun 11 Python
音频处理 windows10下python三方库librosa安装教程
Jun 20 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
php中常用字符串处理代码片段整理
2011/11/07 PHP
PHP fopen()和 file_get_contents()应用与差异介绍
2014/03/19 PHP
ThinkPHP连接数据库及主从数据库的设置教程
2014/08/22 PHP
浅析PHP文件下载原理
2014/12/25 PHP
php计算一个文件大小的方法
2015/03/30 PHP
深入理解PHP 数组之count 函数
2016/06/13 PHP
javascript eval函数深入认识
2009/02/21 Javascript
iframe的onload在Chrome/Opera中执行两次Bug的解决方法
2011/03/17 Javascript
jquery获取下拉列表的值为null的解决方法
2011/03/18 Javascript
Javascript获取当前日期的农历日期代码
2014/10/08 Javascript
详解JavaScript中的every()方法
2015/06/08 Javascript
js实现的Easy Tabs选项卡用法实例
2015/09/06 Javascript
JQuery的常用选择器、过滤器、方法全面介绍
2016/05/25 Javascript
jQuery 3.0十大新特性
2016/07/06 Javascript
javascript设计模式Constructor(构造器)模式
2016/08/19 Javascript
AngularJs表单校验功能实例代码
2017/02/09 Javascript
JavaScript实现淘宝京东6位数字支付密码效果
2018/08/18 Javascript
vue 插槽简介及使用示例
2020/11/19 Vue.js
[04:26]2014DOTA2西雅图国际邀请赛 总决赛TOPPLAY
2014/07/22 DOTA
python中的多重继承实例讲解
2014/09/28 Python
在Python中操作时间之strptime()方法的使用
2020/12/30 Python
Python连接PostgreSQL数据库的方法
2016/11/28 Python
Golang与python线程详解及简单实例
2017/04/27 Python
python中的decorator的作用详解
2018/07/26 Python
Python制作动态字符图的实例
2019/01/27 Python
python的移位操作实现详解
2019/08/21 Python
深入了解如何基于Python读写Kafka
2019/12/31 Python
如何使用Python调整图像大小
2020/09/26 Python
python tkinter的消息框模块(messagebox,simpledialog)
2020/11/07 Python
详解CSS3中常用的样式【基本文本和字体样式】
2020/10/20 HTML / CSS
先进事迹报告会感言
2014/01/24 职场文书
设备管理实施方案
2014/05/31 职场文书
苏州园林导游词
2015/02/03 职场文书
父亲去世追悼词
2015/06/23 职场文书
合作合同协议书
2016/03/21 职场文书
Mysql排序的特性详情
2021/11/01 MySQL