Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python的urllib库提交WEB表单
Feb 24 Python
Python解析xml中dom元素的方法
Mar 12 Python
python解析xml文件实例分析
May 27 Python
python使用response.read()接收json数据的实例
Dec 19 Python
Python设置matplotlib.plot的坐标轴刻度间隔以及刻度范围
Jun 25 Python
django 2.2和mysql使用的常见问题
Jul 18 Python
简单介绍django提供的加密算法
Dec 18 Python
python自动下载图片的方法示例
Mar 25 Python
Django如何使用redis作为缓存
May 21 Python
pytorch  网络参数 weight bias 初始化详解
Jun 24 Python
如何教少儿学习Python编程
Jul 10 Python
python获取对象信息的实例详解
Jul 07 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
深入解析php模板技术原理【一】
2008/01/10 PHP
php中几种常见安全设置详解
2010/04/06 PHP
php中ob_get_length缓冲与获取缓冲长度实例
2014/11/20 PHP
php微信开发之自定义菜单实现
2016/11/18 PHP
[原创]php正则删除html代码中class样式属性的方法
2017/05/24 PHP
PHP连接及操作PostgreSQL数据库的方法详解
2019/01/30 PHP
javascript 写类方式之一
2009/07/05 Javascript
Javascript JSQL,SQL无处不在,
2010/05/05 Javascript
Javascript处理DOM元素事件实现代码
2012/05/23 Javascript
JavaScript实现表格排序方法
2013/06/14 Javascript
jQuery实现密保互斥问题解决方案
2013/08/16 Javascript
JavaScript支持的最大递归调用次数分析
2014/06/24 Javascript
node.js中的fs.readlinkSync方法使用说明
2014/12/17 Javascript
Javascript数组操作函数总结
2015/02/05 Javascript
JavaScript时间操作之年月日星期级联操作
2016/01/15 Javascript
Vue.js Ajax动态参数与列表显示实现方法
2016/10/20 Javascript
js实现选项卡内容切换以及折叠和展开效果【推荐】
2017/01/08 Javascript
vue 中的keep-alive实例代码
2018/07/20 Javascript
从0到1构建vueSSR项目之node以及vue-cli3的配置
2019/03/07 Javascript
vue+eslint+vscode配置教程
2019/08/09 Javascript
解决vue scoped html样式无效的问题
2020/10/24 Javascript
解决ant design vue 表格a-table二次封装,slots渲染的问题
2020/10/28 Javascript
Python 快速实现CLI 应用程序的脚手架
2017/12/05 Python
python web基础之加载静态文件实例
2018/03/20 Python
Python cookbook(字符串与文本)在字符串的开头或结尾处进行文本匹配操作
2018/04/20 Python
Django框架实现的普通登录案例【使用POST方法】
2019/05/15 Python
HTML5打开手机扫码功能及优缺点
2017/11/27 HTML / CSS
美国修容界大佬创建的个人美妆品牌:Kevyn Aucoin Beauty
2018/12/12 全球购物
PREMIUM-MALL法国:行李、箱包及配件在线
2019/05/30 全球购物
简述使用ftp进行文件传输时的两种登录方式?它们的区别是什么?常用的ftp文件传输命令是什么?
2016/11/20 面试题
档案接收函
2014/01/13 职场文书
工程造价专业大学生职业规划范文
2014/03/09 职场文书
毕业设计说明书
2014/05/07 职场文书
房地产销售主管岗位职责
2015/02/13 职场文书
2016国庆节67周年寄语
2015/12/07 职场文书
经典《舰娘》游改全新动画预告 预定11月开播
2022/04/01 日漫