Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 面向对象 成员的访问约束
Dec 23 Python
python 转换 Javascript %u 字符串为python unicode的代码
Sep 06 Python
python使用opencv进行人脸识别
Apr 07 Python
详解Python3操作Mongodb简明易懂教程
May 25 Python
python下实现二叉堆以及堆排序的示例
Sep 29 Python
对变量赋值的理解--Pyton中让两个值互换的实现方法
Nov 29 Python
Python内置模块hashlib、hmac与uuid用法分析
Feb 12 Python
Python实现的redis分布式锁功能示例
May 29 Python
Python将文字转成语音并读出来的实例详解
Jul 15 Python
pytorch 批次遍历数据集打印数据的例子
Dec 30 Python
Python PyQt5模块实现窗口GUI界面代码实例
May 12 Python
Python2手动安装更新pip过程实例解析
Jul 16 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
ThinkPHP和UCenter接口冲突的解决方法
2016/07/25 PHP
PHP 模拟登陆功能实例详解
2019/09/10 PHP
js获取RadioButtonList的Value/Text及选中值等信息实现代码
2013/03/05 Javascript
jquery ui resize 中border-box的bug修正
2015/04/26 Javascript
iPhone手机上搭建nodejs服务器步骤方法
2015/07/06 NodeJs
谈谈JavaScript异步函数发展历程
2015/09/29 Javascript
浅谈jquery之on()绑定事件和off()解除绑定事件
2016/10/26 Javascript
JavaScript ES6中CLASS的使用详解
2016/11/22 Javascript
解决ajax不能访问本地文件问题(利用js跨域原理)
2017/01/24 Javascript
vue2.0 elementUI制作面包屑导航栏
2018/02/22 Javascript
基于 Immutable.js 实现撤销重做功能的实例代码
2018/03/01 Javascript
快速搭建vue2.0+boostrap项目的方法
2018/04/09 Javascript
vue2.0学习之axios的封装与vuex介绍
2018/05/28 Javascript
vue项目部署上线遇到的问题及解决方法
2018/06/10 Javascript
Vue源码解析之数组变异的实现
2018/12/04 Javascript
vue 引用自定义ttf、otf、在线字体的方法
2019/05/09 Javascript
JavaScript 面向对象程序设计详解【类的创建、实例对象、构造函数、原型等】
2020/05/12 Javascript
vue组件系列之TagsInput详解
2020/05/14 Javascript
[51:22]Fnatic vs IG 2018国际邀请赛小组赛BO2 第一场 8.17
2018/08/18 DOTA
玩转python爬虫之正则表达式
2016/02/17 Python
Python元字符的用法实例解析
2018/01/17 Python
Tornado Web Server框架编写简易Python服务器
2018/07/28 Python
对python_discover方法遍历所有执行的用例详解
2019/02/13 Python
python实现计数排序与桶排序实例代码
2019/03/28 Python
Python实现的登录验证系统完整案例【基于搭建的MVC框架】
2019/04/12 Python
信号生成及DFT的python实现方式
2020/02/25 Python
CSS3 filter(滤镜)实现网页灰色或者黑色模式的代码
2020/11/30 HTML / CSS
DVF官方网站:美国时装界尊尚品牌
2017/08/29 全球购物
超市营业员求职简历的自我评价
2013/10/17 职场文书
工伤赔偿协议书范本
2014/04/15 职场文书
2015年元旦活动总结
2014/05/09 职场文书
2014和解协议书范文
2014/09/15 职场文书
伏羲庙导游词
2015/02/09 职场文书
幼儿园毕业致辞
2015/07/29 职场文书
2019个人年度目标制定攻略!
2019/07/12 职场文书
阿里云 Windows server 2019 配置FTP
2022/04/28 Servers