Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列表(list)、字典(dict)、字符串(string)基本操作小结
Nov 28 Python
在Python的Flask框架中使用日期和时间的教程
Apr 21 Python
Python 爬虫模拟登陆知乎
Sep 23 Python
浅谈flask截获所有访问及before/after_request修饰器
Jan 18 Python
Python实现中一次读取多个值的方法
Apr 22 Python
Python3之简单搭建自带服务器的实例讲解
Jun 04 Python
python3射线法判断点是否在多边形内
Jun 28 Python
Python Gitlab Api 使用方法
Aug 28 Python
Python paramiko模块使用解析(实现ssh)
Aug 30 Python
python中web框架的自定义创建
Sep 08 Python
Python如何实现小程序 无限求和平均
Feb 18 Python
tensorflow基于CNN实战mnist手写识别(小白必看)
Jul 20 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
PHP setcookie设置Cookie用法(及设置无效的问题)
2011/07/13 PHP
微博短链接算法php版本实现代码
2012/09/15 PHP
2014年最新推荐的10款 PHP 开发框架
2014/08/01 PHP
php中substr()函数参数说明及用法实例
2014/11/15 PHP
php内存缓存实现方法
2015/01/24 PHP
详解Yii2 rules 的验证规则
2016/12/02 PHP
php+iframe 实现上传文件功能示例
2020/03/04 PHP
iframe 异步加载技术及性能分析
2011/07/19 Javascript
jquery为页面增加快捷键示例
2014/01/31 Javascript
JavaScript中几种排序算法的简单实现
2015/07/29 Javascript
js倒计时显示实例
2016/12/11 Javascript
微信小程序学习(4)-系统配置app.json详解
2017/01/12 Javascript
Node.js中 __dirname 的使用介绍
2017/06/19 Javascript
js实现方块上下左右移动效果
2017/08/17 Javascript
React教程之Props验证的具体用法(Props Validation)
2017/09/04 Javascript
vue实现按钮切换图片
2021/01/20 Vue.js
[02:27]DOTA2英雄基础教程 莱恩
2014/01/17 DOTA
python实现list元素按关键字相加减的方法示例
2017/06/09 Python
Python实现注册、登录小程序功能
2018/09/21 Python
python 分离文件名和路径以及分离文件名和后缀的方法
2018/10/21 Python
python里 super类的工作原理详解
2019/06/19 Python
python 随机森林算法及其优化详解
2019/07/11 Python
python中通过selenium简单操作及元素定位知识点总结
2019/09/10 Python
Python:二维列表下标互换方式(矩阵转置)
2019/12/02 Python
Python使用qrcode二维码库生成二维码方法详解
2020/02/17 Python
Python爬虫代理池搭建的方法步骤
2020/09/28 Python
python中的yield from语法快速学习
2020/11/06 Python
CSS3的颜色渐变效果的示例代码
2017/09/29 HTML / CSS
资深地理教师自我评价
2013/09/21 职场文书
重阳节登山活动方案
2014/02/03 职场文书
大学生个人实习的自我评价
2014/02/15 职场文书
2015年校本培训工作总结
2015/07/24 职场文书
二十年同学聚会感言
2015/07/30 职场文书
Python利用capstone实现反汇编
2022/04/06 Python
Flutter集成高德地图并添加自定义Maker的实践
2022/04/07 Java/Android
Mysql的Table doesn't exist问题及解决
2022/12/24 MySQL