Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django实现的自定义访问日志模块示例
Jun 23 Python
Django与JS交互的示例代码
Aug 23 Python
Android基于TCP和URL协议的网络编程示例【附demo源码下载】
Jan 23 Python
Laravel+Dingo/Api 自定义响应的实现
Feb 17 Python
Python 异常的捕获、异常的传递与主动抛出异常操作示例
Sep 23 Python
Python中如何将一个类方法变为多个方法
Dec 30 Python
tensorflow 实现打印pb模型的所有节点
Jan 23 Python
python查找特定名称文件并按序号、文件名分行打印输出的方法
Apr 24 Python
python继承threading.Thread实现有返回值的子类实例
May 02 Python
Python基于execjs运行js过程解析
Nov 27 Python
一文带你了解Python 四种常见基础爬虫方法介绍
Dec 04 Python
如何判断pytorch是否支持GPU加速
Jun 01 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
中英文字符串翻转函数
2008/12/09 PHP
php 智能404跳转代码,适合换域名没改变目录的网站
2010/06/04 PHP
PHP+MySQL实现的简单投票系统实例
2016/02/24 PHP
Yii2 rbac权限控制之菜单menu实例教程
2016/04/28 PHP
php实现大文件断点续传下载实例代码
2019/10/01 PHP
Javascript 同时提交多个Web表单的方法
2009/02/19 Javascript
select、radio表单回显功能实现避免使用jquery载入赋值
2013/06/08 Javascript
HTML长文本截取含有HTML代码同样适用的两种方法
2013/07/31 Javascript
一个JavaScript的求爱小特效
2014/05/09 Javascript
javascript动态添加checkbox复选框的方法
2015/12/23 Javascript
JavaScript中循环遍历Array与Map的方法小结
2016/03/12 Javascript
使用jQuery操作DOM的方法小结
2017/02/27 Javascript
深入对Vue.js $watch方法的理解
2017/03/20 Javascript
详解node HTTP请求客户端 - Request
2017/05/05 Javascript
Vue声明式渲染详解
2017/05/17 Javascript
微信小程序实现滑动删除效果
2017/05/19 Javascript
vue slot 在子组件中显示父组件传递的模板
2018/03/02 Javascript
vue移动端UI框架实现QQ侧边菜单组件
2018/03/09 Javascript
Vue实现美团app的影院推荐选座功能【推荐】
2018/08/29 Javascript
详解html-webpack-plugin插件(用法总结)
2018/09/12 Javascript
js使用swiper实现层叠轮播效果实例代码
2018/12/12 Javascript
js实现计时器秒表功能
2019/12/16 Javascript
JS面向对象编程——ES6 中class的继承用法详解
2020/03/03 Javascript
vue使用openlayers实现移动点动画
2020/09/24 Javascript
python实现淘宝秒杀脚本
2020/06/23 Python
python实现图片转字符小工具
2019/04/30 Python
解决python tkinter界面卡死的问题
2019/07/17 Python
什么是Python中的顺序表
2020/06/02 Python
Python替换NumPy数组中大于某个值的所有元素实例
2020/06/08 Python
对python pandas中 inplace 参数的理解
2020/06/27 Python
小白教你PyCharm从下载到安装再到科学使用PyCharm2020最新激活码
2020/09/25 Python
电子信息工程专业推荐信
2014/02/14 职场文书
公证委托书大全
2014/04/04 职场文书
公司年终奖分配方案
2014/06/16 职场文书
离婚协议书的范本
2015/01/27 职场文书
Python办公自动化之教你如何用Python将任意文件转为PDF格式
2021/06/28 Python