Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python高效编程技巧
Jan 07 Python
Python通过poll实现异步IO的方法
Jun 04 Python
Python操作RabbitMQ服务器实现消息队列的路由功能
Jun 29 Python
python xml.etree.ElementTree遍历xml所有节点实例详解
Dec 04 Python
python3.6+django2.0开发一套学员管理系统
Mar 03 Python
Python json模块dumps、loads操作示例
Sep 06 Python
python选取特定列 pandas iloc,loc,icol的使用详解(列切片及行切片)
Aug 06 Python
基于django ManyToMany 使用的注意事项详解
Aug 09 Python
pymysql模块的使用(增删改查)详解
Sep 09 Python
django template实现定义临时变量,自定义赋值、自增实例
Jul 12 Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 Python
Python基础之进程详解
May 21 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
咖啡历史、消费和行业趋势
2021/03/03 咖啡文化
ThinkPHP模板IF标签用法详解
2014/07/01 PHP
php获取文章内容第一张图片的方法示例
2017/07/03 PHP
PHP PDOStatement::bindColumn讲解
2019/01/30 PHP
jquery写个checkbox——类似邮箱全选功能
2013/03/19 Javascript
jQuery学习笔记(3)--用jquery(插件)实现多选项卡功能
2013/04/08 Javascript
Jquery同辈元素选中/未选中效果的实例代码
2013/08/01 Javascript
jQuery实现简易的天天爱消除小游戏
2015/10/16 Javascript
浅析JS原型继承与类的继承
2016/04/07 Javascript
JS实现弹出居中的模式窗口示例
2016/06/20 Javascript
总结javascript中的六种迭代器
2016/08/16 Javascript
微信小程序网络请求实现过程解析
2019/11/06 Javascript
分享JS表单验证源码(带错误提示及密码等级)
2020/01/05 Javascript
jquery将信息遍历到界面上实例代码
2020/01/21 jQuery
node.js使用net模块创建服务器和客户端示例【基于TCP协议】
2020/02/14 Javascript
[59:15]完美世界DOTA2联赛PWL S2 LBZS vs FTD.C 第一场 11.20
2020/11/20 DOTA
python使用socket进行简单网络连接的方法
2015/04/29 Python
简单讲解Python中的字符串与字符串的输入输出
2016/03/13 Python
Python正则表达式分组概念与用法详解
2017/06/24 Python
详解python里使用正则表达式的分组命名方式
2017/10/24 Python
Python内置函数delattr的具体用法
2017/11/23 Python
TensorFlow实现简单的CNN的方法
2019/07/18 Python
Python单元测试与测试用例简析
2019/11/09 Python
pytorch 自定义卷积核进行卷积操作方式
2019/12/30 Python
python实现与redis交互操作详解
2020/04/21 Python
Keras - GPU ID 和显存占用设定步骤
2020/06/22 Python
REN Clean Skincare官网:英国本土有机护肤品牌
2019/02/23 全球购物
SmartBuyGlasses荷兰:购买太阳镜和眼镜
2020/03/16 全球购物
银行会计职员个人的自我评价
2013/09/29 职场文书
暑假社会实践心得体会
2014/09/02 职场文书
2015年“公民道德宣传日”活动方案
2015/05/06 职场文书
领导离职感言
2015/08/03 职场文书
销售人员管理制度
2015/08/06 职场文书
志愿者服务宣传标语口号
2015/12/26 职场文书
javascript Number 与 Math对象的介绍
2021/11/17 Javascript
Python tensorflow卷积神经Inception V3网络结构
2022/05/06 Python