Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中lambda与def用法对比实例分析
Apr 30 Python
Python 包含汉字的文件读写之每行末尾加上特定字符
Dec 12 Python
利用Python爬虫给孩子起个好名字
Feb 14 Python
Python 调用Java实例详解
Jun 02 Python
Python 使用PIL numpy 实现拼接图片的示例
May 08 Python
对Python中9种生成新对象的方法总结
May 23 Python
python TKinter获取文本框内容的方法
Oct 11 Python
pygame游戏之旅 创建游戏窗口界面
Nov 20 Python
python用match()函数爬数据方法详解
Jul 23 Python
如何配置关联Python 解释器 Anaconda的教程(图解)
Apr 30 Python
python如何保存文本文件
Jun 07 Python
基于Django快速集成Echarts代码示例
Dec 01 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
用PHP提取中英文词语以及数字的首字母的方法介绍
2013/04/23 PHP
php把session写入数据库示例
2014/02/26 PHP
PHP在网页中动态生成PDF文件详细教程
2014/07/05 PHP
PHP实现的字符串匹配算法示例【sunday算法】
2017/12/19 PHP
php使用array_chunk函数将一个数组分割成多个数组
2018/12/05 PHP
PHP实现批量修改文件名的方法示例
2019/09/18 PHP
js 日期转换成中文格式的函数
2009/07/07 Javascript
jquery 触发a链接点击事件解决方案
2013/05/02 Javascript
jQuery 快速结束当前正在执行的动画
2013/11/20 Javascript
javascript中call,apply,bind的用法对比分析
2015/02/12 Javascript
浅谈javascript中this在事件中的应用
2015/02/15 Javascript
AngularJS学习笔记之ng-options指令
2015/06/16 Javascript
在Python中使用glob模块查找文件路径的方法
2015/06/17 Javascript
浅谈javascript中的DOM方法
2015/07/16 Javascript
jQuery实现类似老虎机滚动抽奖效果
2015/08/06 Javascript
JQuery标签页效果实例详解
2015/12/24 Javascript
JavaScript实现拖拽元素对齐到网格(每次移动固定距离)
2016/11/30 Javascript
jQuery插件FusionCharts实现的Marimekko图效果示例【附demo源码】
2017/03/24 jQuery
vue通过cookie获取用户登录信息的思路详解
2018/10/30 Javascript
javascript中可能用得到的全部的排序算法
2020/03/05 Javascript
Python导入模块时遇到的错误分析
2017/08/30 Python
python中从str中提取元素到list以及将list转换为str的方法
2018/06/26 Python
Django的models模型的具体使用
2019/07/15 Python
Django Rest framework认证组件详细用法
2019/07/25 Python
复化梯形求积分实例——用Python进行数值计算
2019/11/20 Python
用Python画小女孩放风筝的示例
2019/11/23 Python
python如何求数组连续最大和的示例代码
2020/02/04 Python
Larsson & Jennings官网:现代瑞士钟表匠
2018/03/20 全球购物
运动会通讯稿400字
2014/01/28 职场文书
优秀老师事迹材料
2014/02/05 职场文书
销售类求职信
2014/06/13 职场文书
民主生活会意见
2015/06/05 职场文书
学生会2016感恩节活动小结
2016/04/01 职场文书
如何做好工作总结!
2019/04/10 职场文书
SQL IDENTITY_INSERT作用案例详解
2021/08/23 MySQL
vue中控制mock在开发环境使用,在生产环境禁用方式
2022/04/06 Vue.js