Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python字符串替换实例分析
May 11 Python
Python脚本处理空格的方法
Aug 08 Python
使用Python和xlwt向Excel文件中写入中文的实例
Apr 21 Python
python邮件发送smtplib使用详解
Jun 16 Python
Python高阶函数、常用内置函数用法实例分析
Dec 26 Python
python3 Scrapy爬虫框架ip代理配置的方法
Jan 17 Python
opencv python Canny边缘提取实现过程解析
Feb 03 Python
浅谈ROC曲线的最佳阈值如何选取
Feb 28 Python
Eclipse配置python默认头过程图解
Apr 26 Python
python3用PyPDF2解析pdf文件,用正则匹配数据方式
May 12 Python
记一次Django响应超慢的解决过程
Sep 17 Python
python unittest单元测试的步骤分析
Aug 02 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
PHP header()函数常用方法总结
2014/04/11 PHP
PHP静态文件生成类实例
2014/11/29 PHP
php结合ACCESS的跨库查询功能
2015/06/12 PHP
PHPWind9.0手动屏蔽验证码解决后台关闭验证码但是依然显示的问题
2016/08/12 PHP
javascript:void(0)的真正含义实例分析
2008/08/20 Javascript
用 Javascript 验证表单(form)中的单选(radio)值
2009/09/08 Javascript
判断iframe是否加载完成的完美方法
2010/01/07 Javascript
JavaScript 类型的包装对象(Typed Wrappers)
2011/10/27 Javascript
jquery.autocomplete修改实现键盘上下键自动填充示例
2013/11/19 Javascript
javascript对下拉列表框(select)的操作实例讲解
2013/11/29 Javascript
JavaScript设计模式之外观模式介绍
2014/12/28 Javascript
用js判断是否为360浏览器的实现代码
2015/01/15 Javascript
JS+CSS实现简单滑动门(滑动菜单)效果
2015/09/19 Javascript
JavaScript中数组添加值和访问值常见问题
2016/02/06 Javascript
js返回顶部实例分享
2016/12/21 Javascript
微信小程序 网络请求(post请求,get请求)
2017/01/17 Javascript
js 判断登录界面的账号密码是否为空
2017/02/08 Javascript
vue2.0与bootstrap3实现列表分页效果
2017/11/28 Javascript
JS实现监控微信小程序的原理
2018/06/15 Javascript
vue+高德地图写地图选址组件的方法
2019/05/18 Javascript
[03:16]DOTA2完美大师赛主赛事首日集锦
2017/11/23 DOTA
[01:21:36]CHAOS vs Alliacne 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
python读取word文档,插入mysql数据库的示例代码
2018/11/07 Python
python3中property使用方法详解
2019/04/23 Python
pycharm中显示CSS提示的知识点总结
2019/07/29 Python
python 爬取学信网登录页面的例子
2019/08/13 Python
详解Python 最短匹配模式
2020/07/29 Python
详解基于Scrapy的IP代理池搭建
2020/09/29 Python
CSS3 3D酷炫立方体变换动画的实现
2019/03/26 HTML / CSS
校园之星获奖感言
2014/01/29 职场文书
路政管理求职信
2014/06/18 职场文书
高中生国庆节演讲稿范文2014
2014/09/21 职场文书
2019年个人工作总结范文
2019/03/25 职场文书
导游词之四川武侯祠
2019/10/21 职场文书
在NumPy中深拷贝和浅拷贝相关操作的定义和背后的原理
2022/04/14 Python
SpringBoot全局异常处理方案分享
2022/05/25 Java/Android