Keras实现将两个模型连接到一起


Posted in Python onMay 23, 2020

神经网络玩得越久就越会尝试一些网络结构上的大改动。

先说意图

有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。

流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。

实现方法

首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。

第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。

第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。

可以看一个自编码器的代码(本人所编写):

class AE:
 def __init__(self, dim, img_dim, batch_size):
  self.dim = dim
  self.img_dim = img_dim
  self.batch_size = batch_size
  self.encoder = self.encoder_construct()
  self.decoder = self.decoder_construct()
 
 def encoder_construct(self):
  x_in = Input(shape=(self.img_dim, self.img_dim, 3))
  x = x_in
  x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
  x = BatchNormalization()(x)
  x = LeakyReLU(0.2)(x)
  x = GlobalAveragePooling2D()(x)
  encoder = Model(x_in, x)
  return encoder
 
 def decoder_construct(self):
  map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
  # print(type(map_size))
  z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
  z = z_in
  z_dim = self.dim
  z = Dense(np.prod(map_size) * z_dim)(z)
  z = Reshape(map_size + (z_dim,))(z)
  z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = BatchNormalization()(z)
  z = Activation('relu')(z)
  z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
  z = Activation('tanh')(z)
  decoder = Model(z_in, z)
  return decoder
 
 def build_ae(self):
  input_x = Input(shape=(self.img_dim, self.img_dim, 3))
  x = input_x
  for i in range(1, len(self.encoder.layers)):
   x = self.encoder.layers[i](x)
  for j in range(1, len(self.decoder.layers)):
   x = self.decoder.layers[j](x)
  y = x
  auto_encoder = Model(input_x, y)
  return auto_encoder

模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。

补充知识:keras得到每层的系数

使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:

weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)

这样系数就被存放到一个np中了。

以上这篇Keras实现将两个模型连接到一起就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接sql server乱码的解决方法
Jan 28 Python
pycharm 使用心得(三)Hello world!
Jun 05 Python
零基础写python爬虫之urllib2使用指南
Nov 05 Python
Python闭包实现计数器的方法
May 05 Python
Python实现压缩与解压gzip大文件的方法
Sep 18 Python
python数据清洗系列之字符串处理详解
Feb 12 Python
Python使用三种方法实现PCA算法
Dec 12 Python
python写入并获取剪切板内容的实例
May 31 Python
Python实现输入二叉树的先序和中序遍历,再输出后序遍历操作示例
Jul 27 Python
Python中的枚举类型示例介绍
Jan 09 Python
pyqt5实现按钮添加背景图片以及背景图片的切换方法
Jun 13 Python
Python 读写 Matlab Mat 格式数据的操作
May 19 Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
给keras层命名,并提取中间层输出值,保存到文档的实例
May 23 #Python
keras小技巧——获取某一个网络层的输出方式
May 23 #Python
keras自定义回调函数查看训练的loss和accuracy方式
May 23 #Python
Keras设定GPU使用内存大小方式(Tensorflow backend)
May 22 #Python
tensorflow使用L2 regularization正则化修正overfitting过拟合方式
May 22 #Python
Softmax函数原理及Python实现过程解析
May 22 #Python
You might like
PHP下打开phpMyAdmin出现403错误的问题解决方法
2013/05/23 PHP
浅析关于PHP位运算的简单权限设计
2013/06/30 PHP
深入解读php中关于抽象(abstract)类和抽象方法的问题分析
2014/01/03 PHP
PHP使用xmllint命令处理xml与html的方法
2014/12/15 PHP
PHP的swoole扩展安装方法详细教程
2016/05/18 PHP
浅谈PHP安全防护之Web攻击
2017/01/03 PHP
ThinkPHP3.1.2 使用cli命令行模式运行的方法
2020/04/14 PHP
基于jquery的仿百度搜索框效果代码
2011/04/11 Javascript
Blocksit插件实现瀑布流数据无限( 异步)加载
2014/06/20 Javascript
jQuery结合ajax实现动态加载文本内容
2015/05/19 Javascript
原生js实现节日时间倒计时功能
2017/01/18 Javascript
Angular 4根据组件名称动态创建出组件的方法教程
2017/11/01 Javascript
利用jquery如何从json中读取数据追加到html中
2017/12/01 jQuery
详解JavaScript栈内存与堆内存
2019/04/04 Javascript
微信小程序环境下将文件上传到OSS的方法步骤
2019/05/31 Javascript
微信小程序 导入图标实现过程详解
2019/10/11 Javascript
python实现简单socket通信的方法
2016/04/19 Python
Python处理XML格式数据的方法详解
2017/03/21 Python
Python常见字符串操作函数小结【split()、join()、strip()】
2018/02/02 Python
Django Admin实现三级联动的示例代码(省市区)
2018/06/22 Python
Python图像处理之简单画板实现方法示例
2018/08/30 Python
Python3网络爬虫中的requests高级用法详解
2019/06/18 Python
python pandas获取csv指定行 列的操作方法
2019/07/12 Python
python3 中的字符串(单引号、双引号、三引号)以及字符串与数字的运算
2019/07/18 Python
python 多进程共享全局变量之Manager()详解
2019/08/15 Python
分享一个pycharm专业版安装的永久使用方法
2019/09/24 Python
python多线程案例之多任务copy文件完整实例
2019/10/29 Python
ubuntu 18.04 安装opencv3.4.5的教程(图解)
2019/11/04 Python
PyQT5 实现快捷键复制表格数据的方法示例
2020/06/19 Python
什么是CSS3 HSLA色彩模式?HSLA模拟渐变色条
2016/04/26 HTML / CSS
台湾生鲜宅配:大口市集
2017/10/14 全球购物
澳大利亚便宜的家庭购物网站:CrazySales
2018/02/06 全球购物
PHP如何防止SQL注入
2014/05/03 面试题
国际政治个人自荐信范文
2013/11/26 职场文书
小学生勤俭节约演讲稿
2014/08/28 职场文书
领导干部作风整顿个人剖析材料
2014/10/11 职场文书