Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
SublimeText 2编译python出错的解决方法(The system cannot find the file specified)
Nov 27 Python
python定时检查某个进程是否已经关闭的方法
May 20 Python
python中urllib.unquote乱码的原因与解决方法
Apr 24 Python
用 Python 连接 MySQL 的几种方式详解
Apr 04 Python
Python cookbook(字符串与文本)针对任意多的分隔符拆分字符串操作示例
Apr 19 Python
python接口自动化测试之接口数据依赖的实现方法
Apr 26 Python
python多线程实现代码(模拟银行服务操作流程)
Jan 13 Python
Python 解决火狐浏览器不弹出下载框直接下载的问题
Mar 09 Python
Pycharm安装第三方库失败解决方案
Nov 17 Python
Python基于Webhook实现github自动化部署
Nov 28 Python
详解python中的异常捕获
Dec 15 Python
python实现的人脸识别打卡系统
May 08 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
PHP获取mysql数据表的字段名称和详细信息的方法
2014/09/27 PHP
Thinkphp多文件上传实现方法
2014/10/31 PHP
PHP简单实现上一页下一页功能示例
2016/09/14 PHP
Document 对象的常用方法
2009/07/31 Javascript
EASYUI TREEGRID异步加载数据实现方法
2012/08/22 Javascript
浅谈javascript六种数据类型以及特殊注意点
2013/12/20 Javascript
模拟用户点击弹出新页面不会被浏览器拦截
2014/04/08 Javascript
浅析javascript操作 cookie对象
2014/12/26 Javascript
jquery实现鼠标滑过小图查看大图的方法
2015/07/20 Javascript
原生JavaScript实现动态省市县三级联动下拉框菜单实例代码
2016/02/03 Javascript
Bootstrap学习笔记之css组件(3)
2016/06/07 Javascript
详细总结Javascript中的焦点管理
2016/09/17 Javascript
JS正则验证多个邮箱完整实例【邮箱用分号隔开】
2017/04/19 Javascript
在Vue组件化中利用axios处理ajax请求的使用方法
2017/08/25 Javascript
快速搭建vue2.0+boostrap项目的方法
2018/04/09 Javascript
Vue.js 通过jQuery ajax获取数据实现更新后重新渲染页面的方法
2018/08/09 jQuery
three.js实现炫酷的全景3D重力感应
2018/12/30 Javascript
C#程序员入门学习微信小程序的笔记
2019/03/05 Javascript
Node.js Event Loop各阶段讲解
2019/03/08 Javascript
解决vue axios跨域 Request Method: OPTIONS问题(预检请求)
2020/08/14 Javascript
[41:12]Liquid vs Secret 2019国际邀请赛淘汰赛 败者组 BO3 第一场 8.24
2019/09/10 DOTA
对python 树状嵌套结构的实现思路详解
2019/08/09 Python
中国宠物用品商城:E宠商城
2016/08/27 全球购物
澳大利亚购买最佳炊具品牌网站:Cookware Brands
2019/02/16 全球购物
馥绿德雅美国官方网站:Rene Furterer头皮护理专家
2019/05/01 全球购物
英国钻石公司:British Diamond Company
2020/02/16 全球购物
Nike俄罗斯官方网站:Nike RU
2021/03/05 全球购物
思想汇报格式
2014/01/05 职场文书
会计电算化个人求职信范文
2014/01/24 职场文书
创先争优活动方案
2014/02/12 职场文书
艺术学院毕业生自荐信
2014/07/05 职场文书
银行反四风对照检查材料
2014/09/29 职场文书
《家》读后感:万惜拯救,冷暖自知
2019/09/25 职场文书
创业计划书之密室逃脱
2019/11/08 职场文书
撤回我也能看到!教你用Python制作微信防撤回脚本
2021/06/11 Python
Python常用配置文件ini、json、yaml读写总结
2021/07/09 Python