Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现爬取知乎神回复简单爬虫代码分享
Jan 04 Python
在Python的web框架中编写创建日志的程序的教程
Apr 30 Python
Python冒泡排序注意要点实例详解
Sep 09 Python
python利用OpenCV2实现人脸检测
Apr 16 Python
python实现list由于numpy array的转换
Apr 04 Python
详谈python3中用for循环删除列表中元素的坑
Apr 19 Python
Python函数参数操作详解
Aug 03 Python
根据tensor的名字获取变量的值方式
Jan 04 Python
什么是python类属性
Jun 10 Python
浅谈Python协程
Jun 17 Python
爬虫代理的cookie如何生成运行
Sep 22 Python
Python学习开发之图形用户界面详解
Aug 23 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
收音机频率指针指示不准确和灵敏度低问题
2021/03/02 无线电
php设计模式 DAO(数据访问对象模式)
2011/06/26 PHP
PHP实现通过中文字符比率来判断垃圾评论的方法
2014/10/20 PHP
使用composer命令加载vendor中的第三方类库 的方法
2019/07/09 PHP
JQuery.uploadify 上传文件插件的使用详解 for ASP.NET
2010/01/22 Javascript
javascript简易缓动插件(源码打包)
2012/02/16 Javascript
对于this和$(this)的个人理解
2013/09/08 Javascript
JS格式化数字保留两位小数点示例代码
2013/10/15 Javascript
图片放大镜jquery.jqzoom.js使用实例附放大镜图标
2014/06/19 Javascript
js自动生成的元素与页面原有元素发生堆叠的解决方法
2014/09/04 Javascript
jQuery插件Tmpl的简单使用方法
2015/04/27 Javascript
jQuery leonaScroll 1.1 自定义滚动条插件(推荐)
2016/09/17 Javascript
聊聊JS动画库 Velocity.js的使用
2018/03/13 Javascript
在vue中使用echarts图表实例代码详解
2018/10/22 Javascript
VUE DEMO之模拟登录个人中心页面之间数据传值实例
2019/10/31 Javascript
js正则匹配多个全部数据问题
2019/12/20 Javascript
微信小程序实现录制、试听、上传音频功能(带波形图)
2020/02/27 Javascript
[06:16]DOTA2守卫传承者——职业选手谈心路历程
2015/02/26 DOTA
利用soaplib搭建webservice详细步骤和实例代码
2013/11/20 Python
Python浅拷贝与深拷贝用法实例
2015/05/09 Python
python 拼接文件路径的方法
2018/10/23 Python
使用python将mysql数据库的数据转换为json数据的方法
2019/07/01 Python
python读出当前时间精度到秒的代码
2019/07/05 Python
通过实例解析Python return运行原理
2020/03/04 Python
Python HTTP下载文件并显示下载进度条功能的实现
2020/04/02 Python
如何使用Django Admin管理后台导入CSV
2020/11/06 Python
荷兰优雅女装网上商店:Heine
2016/11/14 全球购物
大学生标准推荐信范文
2013/11/25 职场文书
优秀毕业生求职信
2014/06/05 职场文书
经营理念标语
2014/06/21 职场文书
安全生产工作汇报
2014/10/28 职场文书
办公用品管理制度
2015/08/04 职场文书
2016年中秋节慰问信
2015/12/01 职场文书
煤矿安全生产管理协议书
2016/03/22 职场文书
golang 实现并发求和
2021/05/08 Golang
Elasticsearch 批量操作
2022/04/19 Python