Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python设计模式之抽象工厂模式
Aug 25 Python
Python多线程扫描端口代码示例
Feb 09 Python
python使用sqlite3时游标使用方法
Mar 13 Python
Python 实现删除某路径下文件及文件夹的实例讲解
Apr 24 Python
python实现定时发送qq消息
Jan 18 Python
python getpass模块用法及实例详解
Oct 07 Python
Python使用指定字符长度切分数据示例
Dec 05 Python
python Manager 之dict KeyError问题的解决
Dec 21 Python
解决python3插入mysql时内容带有引号的问题
Mar 02 Python
Python错误的处理方法
Jun 23 Python
使用Python+Appuim 清理微信的方法
Jan 26 Python
Python 正则模块详情
Nov 02 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
一个目录遍历函数
2006/10/09 PHP
在线增减.htpasswd内的用户
2006/10/09 PHP
PHP实现的Redis多库选择功能单例类
2017/07/27 PHP
Javascript公共脚本库系列(一): 弹出层脚本
2011/02/24 Javascript
关于event.cancelBubble和event.stopPropagation()的区别介绍
2011/12/11 Javascript
jquery 滚动条事件简单实例
2013/07/12 Javascript
使用js写的一个简易的投票
2013/11/27 Javascript
基于Vue.js的表格分页组件
2016/05/22 Javascript
node.js 动态执行脚本
2016/06/02 Javascript
微信小程序 页面跳转传参详解
2016/10/28 Javascript
微信小程序 ES6Promise.all批量上传文件实现代码
2017/04/14 Javascript
node.js中事件触发器events的使用方法实例分析
2019/11/23 Javascript
JS实现点击掉落特效
2021/01/29 Javascript
使用Python进行稳定可靠的文件操作详解
2013/12/31 Python
Python编写登陆接口的方法
2017/07/10 Python
pyqt5的QComboBox 使用模板的具体方法
2018/09/06 Python
Python线程下使用锁的技巧分享
2018/09/13 Python
Python随机生成身份证号码及校验功能
2018/12/04 Python
python根据url地址下载小文件的实例
2018/12/18 Python
Python和Java的语法对比分析语法简洁上python的确完美胜出
2019/05/10 Python
Python+redis通过限流保护高并发系统
2020/04/15 Python
HTTP状态码详解
2021/03/18 杂记
Lands’ End官网:经典的美国生活方式品牌
2016/08/14 全球购物
英国景点门票网站:attractiontix
2019/08/27 全球购物
企业厂长岗位职责
2013/12/17 职场文书
中学运动会广播稿
2014/01/19 职场文书
素质拓展感言
2014/01/29 职场文书
揭牌仪式策划方案
2014/05/28 职场文书
最美孝心少年事迹材料
2014/08/15 职场文书
校长创先争优承诺书
2014/08/30 职场文书
小学生通知书评语
2014/12/31 职场文书
离职感谢信
2015/01/21 职场文书
2016年学校安全教育月活动总结
2016/04/06 职场文书
标准版个人借条怎么写?以及什么是借条?
2019/08/28 职场文书
Django 实现jwt认证的示例
2021/04/30 Python
Python实现视频中添加音频工具详解
2021/12/06 Python