Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Tkinter实现简易计算器功能
Jan 30 Python
浅谈Python的list中的选取范围
Nov 12 Python
深入浅析python3中的unicode和bytes问题
Jul 03 Python
Django forms表单 select下拉框的传值实例
Jul 19 Python
FFrpc python客户端lib使用解析
Aug 24 Python
解决pandas展示数据输出时列名不能对齐的问题
Nov 18 Python
Pycharm使用远程linux服务器conda/python环境在本地运行的方法(图解))
Dec 09 Python
python3 实现口罩抽签的功能
Mar 11 Python
linux 下selenium chrome使用详解
Apr 02 Python
Jupyter安装链接aconda实现过程图解
Nov 02 Python
20行代码教你用python给证件照换底色的方法示例
Feb 05 Python
用Python监控你的朋友都在浏览哪些网站?
May 27 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php防止伪造的数据从URL提交方法
2014/06/27 PHP
php上传图片类及用法示例
2016/05/11 PHP
在JavaScript中使用inline函数的问题
2007/03/08 Javascript
js实现双向链表互联网机顶盒实战应用实现
2011/10/28 Javascript
多次注册事件会导致一个事件被触发多次的解决方法
2013/08/12 Javascript
js 绑定键盘鼠标事件示例代码
2014/02/12 Javascript
微信小程序基于slider组件动态修改标签透明度的方法示例
2017/12/04 Javascript
页面点击小红心js实现代码
2018/05/26 Javascript
vue中的watch监听数据变化及watch中各属性的详解
2018/09/11 Javascript
微信小程序中weui用法解析
2019/10/21 Javascript
ES6 Iterator遍历器原理,应用场景及相关常用知识拓展详解
2020/02/15 Javascript
如何实现echarts markline标签名显示自己想要的
2020/07/20 Javascript
封装Vue Element的table表格组件的示例详解
2020/08/19 Javascript
[56:01]2018DOTA2亚洲邀请赛 3.31 小组赛 B组 Effect vs EG
2018/03/31 DOTA
[51:20]完美世界DOTA2联赛PWL S2 Magma vs PXG 第一场 11.28
2020/12/01 DOTA
python爬虫教程之爬取百度贴吧并下载的示例
2014/03/07 Python
centos系统升级python 2.7.3
2014/07/03 Python
python代码实现ID3决策树算法
2017/12/20 Python
详解用Python处理HTML转义字符的5种方式
2017/12/27 Python
python 定时修改数据库的示例代码
2018/04/08 Python
python web自制框架之接受url传递过来的参数实例
2018/12/17 Python
详解Python中的GIL(全局解释器锁)详解及解决GIL的几种方案
2021/01/29 Python
HTML5拖放API实现自动生成相框功能
2020/04/07 HTML / CSS
canvas小画板之平滑曲线的实现
2020/08/12 HTML / CSS
萌新的HTML5 入门指南
2020/11/06 HTML / CSS
阿玛尼化妆品美国官网:Giorgio Armani Beauty
2017/02/02 全球购物
台湾团购、宅配和优惠券:17Life
2017/08/14 全球购物
关于逃课的检讨书
2014/01/23 职场文书
公务员平时考核实施方案
2014/03/11 职场文书
聚美优品励志广告词
2014/03/14 职场文书
《大禹治水》教学反思
2014/04/27 职场文书
怀孕辞职信怎么写
2015/02/28 职场文书
企业法人代表证明书
2015/06/18 职场文书
给校长的建议书作文300字
2015/09/14 职场文书
《一面五星红旗》教学反思
2016/02/23 职场文书
Python爬虫进阶之Beautiful Soup库详解
2021/04/29 Python