Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python base64编码解码实例
Jun 21 Python
Python实现ssh批量登录并执行命令
Oct 25 Python
运动检测ViBe算法python实现代码
Jan 09 Python
在python带权重的列表中随机取值的方法
Jan 23 Python
简单了解python协程的相关知识
Aug 31 Python
python处理excel绘制雷达图
Oct 18 Python
Django对接支付宝实现支付宝充值金币功能示例
Dec 17 Python
Django 多对多字段的更新和插入数据实例
Mar 31 Python
django和flask哪个值得研究学习
Jul 31 Python
Python 实现进度条的六种方式
Jan 06 Python
Pytorch - TORCH.NN.INIT 参数初始化的操作
Feb 27 Python
python工具dtreeviz决策树可视化和模型可解释性
Mar 03 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
PHP的FTP学习(二)[转自奥索]
2006/10/09 PHP
不重新编译PHP为php增加openssl模块的方法
2011/06/14 PHP
php calender(日历)二个版本代码示例(解决2038问题)
2013/12/24 PHP
php array_merge函数使用需要注意的一个问题
2015/03/30 PHP
PHP线程的内存回收问题
2016/07/08 PHP
PHP编程实现微信企业向用户付款的方法示例
2017/07/26 PHP
PHP实现Snowflake生成分布式唯一ID的方法示例
2020/08/30 PHP
Javascript - HTML的request类
2006/07/15 Javascript
url 特殊字符 传递参数解决方法
2010/01/01 Javascript
理清apply(),call()的区别和关系
2011/08/14 Javascript
js取滚动条的尺寸的函数代码
2011/11/30 Javascript
JavaScript中两个感叹号的作用说明
2011/12/28 Javascript
利用NodeJS的子进程(child_process)调用系统命令的方法分享
2013/06/05 NodeJs
简单的JavaScript互斥锁分享
2014/02/02 Javascript
jQuery Ajax()方法使用指南
2014/11/19 Javascript
javascript使用call调用微信API
2014/12/15 Javascript
jQuery实现360°全景拖动展示
2015/03/18 Javascript
jQuery zclip插件实现跨浏览器复制功能
2015/11/02 Javascript
WEB前端实现裁剪上传图片功能
2016/10/17 Javascript
JS实现JSON.stringify的实例代码讲解
2017/02/07 Javascript
基于bootstrap按钮式下拉菜单组件的搜索建议插件
2017/03/25 Javascript
解决vue router使用 history 模式刷新后404问题
2017/07/19 Javascript
vue select二级联动第二级默认选中第一个option值的实例
2018/01/10 Javascript
React Navigation 使用中遇到的问题小结
2018/05/08 Javascript
jQuery动态移除与增加onclick属性的方法详解
2018/06/07 jQuery
了不起的11个JavaScript代码重构最佳实践小结
2021/01/11 Javascript
[02:40]DOTA2英雄基础教程 炼金术士
2013/12/23 DOTA
[02:11]DOTA2上海特级锦标赛主赛事第二日RECAP
2016/03/04 DOTA
[53:15]Newbee vs Pain 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
在Mac OS上部署Nginx和FastCGI以及Flask框架的教程
2015/05/02 Python
Python2.x利用commands模块执行Linux shell命令
2016/03/11 Python
python实现百度语音识别api
2018/04/10 Python
Python2和Python3中@abstractmethod使用方法
2020/02/04 Python
教师党员先进性教育自我剖析材料思想汇报
2014/09/24 职场文书
晋江市人民政府党组群众路线教育实践活动整改方案
2014/10/25 职场文书
MySQL里面的子查询的基本使用
2021/08/02 MySQL