Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
多版本Python共存的配置方法
May 22 Python
python生成二维码的实例详解
Oct 29 Python
使用python实现BLAST
Feb 12 Python
基于pandas数据样本行列选取的方法
Apr 20 Python
Pandas读取并修改excel的示例代码
Feb 17 Python
set在python里的含义和用法
Jun 24 Python
python中update的基本使用方法详解
Jul 17 Python
浅谈Python3 numpy.ptp()最大值与最小值的差
Aug 24 Python
python爬虫开发之使用python爬虫库requests,urllib与今日头条搜索功能爬取搜索内容实例
Mar 10 Python
Python开发之身份证验证库id_validator验证身份证号合法性及根据身份证号返回住址年龄等信息
Mar 20 Python
opencv 图像轮廓的实现示例
Jul 08 Python
python实现经典排序算法的示例代码
Feb 07 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php运行出现Call to undefined function curl_init()的解决方法
2010/11/02 PHP
php利用腾讯ip分享计划获取地理位置示例分享
2014/01/20 PHP
PHP面向对象多态性实现方法简单示例
2017/09/27 PHP
PHP实现二叉树深度优先遍历(前序、中序、后序)和广度优先遍历(层次)实例详解
2018/04/20 PHP
PHP实现的函数重载功能示例
2018/08/03 PHP
二级域名转向类
2006/11/09 Javascript
拖拉表格的JS函数
2008/11/20 Javascript
表单元素事件 (Form Element Events)
2009/07/17 Javascript
基于Jquery的文字滚动跑马灯插件(一个页面多个滚动区)
2010/07/26 Javascript
JavaScript中圆括号()和方括号[]的特殊用法疑问解答
2013/08/06 Javascript
javascript实现的图片切割多块效果实例
2015/05/07 Javascript
深入探讨javascript函数式编程
2015/10/11 Javascript
详解基于angular路由的requireJs按需加载js
2017/01/20 Javascript
javascript设计模式之中介者模式学习笔记
2017/02/15 Javascript
js/jq仿window文件夹框选操作插件
2017/03/08 Javascript
详解Vue2+Echarts实现多种图表数据可视化Dashboard(附源码)
2017/03/21 Javascript
详解react使用react-bootstrap当轮子造车
2017/08/15 Javascript
详解微信小程序实现跑马灯效果(附完整代码)
2019/04/29 Javascript
vue项目中使用scss的方法步骤
2019/05/16 Javascript
微信小程序进入广告实现代码实例
2019/09/19 Javascript
python去除字符串中的换行符
2017/10/11 Python
python通过百度地图API获取某地址的经纬度详解
2018/01/28 Python
python实现QQ邮箱/163邮箱的邮件发送
2019/01/22 Python
Django框架ORM数据库操作实例详解
2019/11/07 Python
python3注册全局热键的实现
2020/03/22 Python
pandas分组聚合详解
2020/04/10 Python
解决python Jupyter不能导入外部包问题
2020/04/15 Python
关于Python不换行输出和不换行输出end=““不显示的问题(亲测已解决)
2020/10/27 Python
Python字典实现伪切片功能
2020/10/28 Python
Python tkinter之Bind(绑定事件)的使用示例
2021/02/05 Python
欧洲品牌瓷器餐具网上商店:Porzellantreff.de
2018/04/04 全球购物
Fresh馥蕾诗英国官网:法国LVMH集团旗下高端天然护肤品牌
2018/11/01 全球购物
香港彩色隐形眼镜在线商店:Stunninglens(全球免费送货)
2019/05/10 全球购物
工程地质勘察专业大学生求职信
2013/10/13 职场文书
表彰先进集体通报
2014/01/12 职场文书
领导干部培训感言
2014/01/23 职场文书