Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之二叉树的遍历实例
Apr 29 Python
玩转python爬虫之URLError异常处理
Feb 17 Python
selenium+python自动化测试之鼠标和键盘事件
Jan 23 Python
python使用thrift教程的方法示例
Mar 21 Python
python读取并写入mat文件的方法
Jul 12 Python
运用PyTorch动手搭建一个共享单车预测器
Aug 06 Python
python数据类型之间怎么转换技巧分享
Aug 20 Python
python pygame实现挡板弹球游戏
Nov 25 Python
Flask和pyecharts实现动态数据可视化
Feb 26 Python
Python中如何引入第三方模块
May 27 Python
Python使用Matlab命令过程解析
Jun 04 Python
Django创建一个后台的基本步骤记录
Oct 02 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
PHP调用三种数据库的方法(3)
2006/10/09 PHP
php数组函数序列之rsort() - 对数组的元素值进行降序排序
2011/11/02 PHP
php判断上传的Excel文件中是否有图片及PHPExcel库认识
2013/01/11 PHP
php递归删除目录下的文件但保留的实例分享
2014/05/10 PHP
thinkphp表单上传文件并将文件路径保存到数据库中
2016/07/28 PHP
PHP常用函数之base64图片上传功能详解
2019/10/21 PHP
获取body标签的两种方法
2011/10/13 Javascript
js获取通过ajax返回的map型的JSONArray的方法
2014/01/09 Javascript
利用JQuery写一个简单的异步分页插件
2016/03/07 Javascript
非常漂亮的相册集 使用jquery制作相册集
2016/04/28 Javascript
Node.js中常规的文件操作总结
2016/10/13 Javascript
Angularjs实现搜索关键字高亮显示效果
2017/01/17 Javascript
详解Vue取消eslint语法限制
2018/08/04 Javascript
微信小程序中为什么使用var that=this
2019/08/27 Javascript
Node登录权限验证token验证实现的方法示例
2020/05/25 Javascript
如何解决jQuery 和其他JS库的冲突
2020/06/22 jQuery
Vue父子组件传值的一些坑
2020/09/16 Javascript
[49:56]VG vs Optic 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
在Python的Django框架中加载模版的方法
2015/07/16 Python
python利用rsa库做公钥解密的方法教程
2017/12/10 Python
Python实现元素等待代码实例
2019/11/11 Python
Python中如何引入第三方模块
2020/05/27 Python
通过cmd进入python的步骤
2020/06/16 Python
Ubuntu 20.04安装Pycharm2020.2及锁定到任务栏的问题(小白级操作)
2020/10/29 Python
怀俄明州飞钓:Platte River Fly Shop
2017/12/28 全球购物
Hunkemöller瑞士网上商店:欧洲最大的内衣品牌之一
2018/12/03 全球购物
Java里面Pass by value和Pass by Reference是什么意思
2016/05/02 面试题
大学毕业生文采飞扬的自我鉴定
2013/12/03 职场文书
品质口号大全
2014/06/17 职场文书
房屋分割离婚协议书范本
2014/12/01 职场文书
2015年世界无烟日活动总结
2015/02/10 职场文书
2015年七一建党节演讲稿
2015/03/19 职场文书
开学第一周日记(三篇范文)
2019/08/23 职场文书
《蓝鲸的眼睛》读后感5篇
2020/01/15 职场文书
JVM上高性能数据格式库包Apache Arrow入门和架构详解(Gkatziouras)
2021/05/26 Servers
总结三种用 Python 作为小程序后端的方式
2022/05/02 Python