Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python类装饰器用法实例
Jun 04 Python
django 实现电子支付功能的示例代码
Jul 25 Python
python实现简单tftp(基于udp协议)
Jul 30 Python
python实现字符串中字符分类及个数统计
Sep 28 Python
Pandas GroupBy对象 索引与迭代方法
Nov 16 Python
Django实现一对多表模型的跨表查询方法
Dec 18 Python
使用python和pygame制作挡板弹球游戏
Dec 03 Python
python实现监控阿里云账户余额功能
Dec 16 Python
关于Python Tkinter Button控件command传参问题的解决方式
Mar 04 Python
python+playwright微软自动化工具的使用
Feb 02 Python
Python将QQ聊天记录生成词云的示例代码
Feb 10 Python
浅谈Python numpy创建空数组的问题
May 25 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
绿山咖啡和蓝山咖啡
2021/03/04 新手入门
php中显示数组与对象的实现代码
2011/04/18 PHP
PHP实现模仿socket请求返回页面的方法
2014/11/04 PHP
一个非常完美的读写ini格式的PHP配置类分享
2015/02/12 PHP
客户端静态页面玩分页
2006/06/26 Javascript
js实现拖拽 闭包函数详细介绍
2012/11/25 Javascript
深入分析jsonp协议原理
2015/09/26 Javascript
jQuery回调方法使用示例
2017/06/26 jQuery
浅谈ES6 模板字符串的具体使用方法
2017/11/07 Javascript
微信小程序上传图片实例
2018/05/28 Javascript
JS中call和apply函数用法实例分析
2018/06/20 Javascript
浅析vue.js数组的变异方法
2018/06/30 Javascript
vue封装swiper代码实例解析
2019/10/08 Javascript
jquery将json转为数据字典的实例代码
2019/10/11 jQuery
[38:44]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第二局
2016/02/25 DOTA
[38:39]完美世界DOTA2联赛循环赛 IO vs GXR BO2第二场 11.04
2020/11/05 DOTA
Python中使用Tkinter模块创建GUI程序实例
2015/01/14 Python
Python中实现三目运算的方法
2015/06/21 Python
Python算法输出1-9数组形成的结果为100的所有运算式
2017/11/03 Python
python自动截取需要区域,进行图像识别的方法
2018/05/17 Python
Python基于机器学习方法实现的电影推荐系统实例详解
2019/06/25 Python
CSS3 RGBA色彩模式使用实例讲解
2016/04/26 HTML / CSS
西班牙用户之间买卖视频游戏的平台:Wakkap
2020/03/21 全球购物
Java面试题:Java类的Main方法如果是Private将会怎么样
2016/08/18 面试题
餐厅经理岗位职责范本
2014/02/17 职场文书
社区禁毒工作方案
2014/06/02 职场文书
个人违纪检讨书
2014/09/15 职场文书
教师群众路线剖析材料
2014/09/29 职场文书
大学生联谊活动策划书(光棍节)
2014/10/10 职场文书
乡镇三严三实学习心得体会
2014/10/13 职场文书
2015年食堂工作总结报告
2015/04/23 职场文书
2015年科普工作总结
2015/07/23 职场文书
患者身份识别制度
2015/08/06 职场文书
Pytorch中Softmax和LogSoftmax的使用详解
2021/06/05 Python
分析JVM源码之Thread.interrupt系统级别线程打断
2021/06/29 Java/Android
电频谱管理的原则是什么
2022/02/18 无线电