Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
介绍Python中内置的itertools模块
Apr 29 Python
python 实时遍历日志文件
Apr 12 Python
基于Python闭包及其作用域详解
Aug 28 Python
Python实现嵌套列表去重方法示例
Dec 28 Python
python学生管理系统代码实现
Apr 05 Python
Python/ArcPy遍历指定目录中的MDB文件方法
Oct 27 Python
详解PyCharm+QTDesigner+PyUIC使用教程
Jun 13 Python
python调用并链接MATLAB脚本详解
Jul 05 Python
python实现BP神经网络回归预测模型
Aug 09 Python
Python下应用opencv 实现人脸检测功能
Oct 24 Python
Pytorch 搭建分类回归神经网络并用GPU进行加速的例子
Jan 09 Python
PYQT5 vscode联合操作qtdesigner的方法
Mar 24 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php 缓存函数代码
2008/08/27 PHP
php curl 伪造IP来源的实例代码
2012/11/01 PHP
基于PHP实现短信验证码接口(容联运通讯)
2016/09/06 PHP
PHP进程通信基础之信号量与共享内存通信
2017/02/19 PHP
php实现的统计字数函数定义与使用示例
2017/07/26 PHP
2017年最好用的9个php开发工具推荐(超好用)
2017/10/23 PHP
一文掌握PHP Xdebug 本地与远程调试(小结)
2019/04/23 PHP
IE下window.onresize 多次调用与死循环bug处理方法介绍
2013/11/12 Javascript
使用jquery中height()方法获取各种高度大全
2014/04/02 Javascript
vue2+el-menu实现路由跳转及当前项的设置方法实例
2017/11/07 Javascript
vue项目中仿element-ui弹框效果的实例代码
2019/04/22 Javascript
JS使用正则表达式提交页面验证的代码
2019/10/16 Javascript
微信小程序工具函数封装
2019/10/28 Javascript
在vue-cli3.0 中使用预处理器 (Sass/Less/Stylus) 配置全局变量操作
2020/08/10 Javascript
[45:16]完美世界DOTA2联赛PWL S3 Magma vs Phoenix 第一场 12.12
2020/12/16 DOTA
python3.5 + PyQt5 +Eric6 实现的一个计算器代码
2017/03/11 Python
Python基本socket通信控制操作示例
2019/01/30 Python
节日快乐! Python画一棵圣诞树送给你
2019/12/24 Python
tensorflow模型的save与restore,及checkpoint中读取变量方式
2020/05/26 Python
python编写一个会算账的脚本的示例代码
2020/06/02 Python
pandas数据处理之绘图的实现
2020/06/15 Python
使用python修改文件并立即写回到原始位置操作(inplace读写)
2020/06/28 Python
html5 Canvas画图教程(6)—canvas里画曲线之arcTo方法
2013/01/09 HTML / CSS
浅析HTML5 Landmark
2020/09/11 HTML / CSS
澳大利亚家庭花园和DIY工具网店:VidaXL
2019/05/03 全球购物
电子商务网站的创业计划书
2014/01/05 职场文书
秋季运动会广播稿
2014/02/22 职场文书
2014年财政工作总结
2014/12/10 职场文书
学校食品安全责任书
2015/01/29 职场文书
大学生年度个人总结
2015/02/15 职场文书
未中标通知书
2015/04/17 职场文书
2016幼儿园中班开学寄语
2015/12/03 职场文书
检举信的写法
2019/04/10 职场文书
十一月早安语录:把心放轻,人生就是一朵自在的云
2019/11/04 职场文书
Pytorch中expand()的使用(扩展某个维度)
2022/07/15 Python
Elasticsearch6.2服务器升配后的bug(避坑指南)
2022/09/23 Servers