Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
零基础写python爬虫之urllib2中的两个重要概念:Openers和Handlers
Nov 05 Python
在Python中使用成员运算符的示例
May 13 Python
深入解析Python中的集合类型操作符
Aug 19 Python
Python基于select实现的socket服务器
Apr 13 Python
Python网络爬虫项目:内容提取器的定义
Oct 25 Python
基于Python的关键字监控及告警
Jul 06 Python
python中实现精确的浮点数运算详解
Nov 02 Python
利用Pyhton中的requests包进行网页访问测试的方法
Dec 26 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 Python
python实现UDP协议下的文件传输
Mar 20 Python
Java如何基于wsimport调用wcf接口
Jun 17 Python
详解Python 循环嵌套
Jul 09 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php获取操作系统语言代码
2013/11/04 PHP
详解PHP中的Traits
2015/07/29 PHP
教你识别简单的免查杀PHP后门
2015/09/13 PHP
PHP使用file_get_content设置头信息的方法
2016/02/14 PHP
PHP创建文件及写入数据(覆盖写入,追加写入)的方法详解
2019/02/15 PHP
Cookie 注入是怎样产生的
2009/04/08 Javascript
Javascript 中文字符串处理额外注意事项
2009/11/15 Javascript
JavaScript中也使用$美元符号来代替document.getElementById
2010/06/19 Javascript
js代码实现的加入收藏效果并兼容主流浏览器
2014/06/23 Javascript
jquery实现一个简单好用的弹出框
2014/09/26 Javascript
jQuery插件EasyUI校验规则 validatebox验证框
2015/11/29 Javascript
原生js实现移动端瀑布流式代码示例
2015/12/18 Javascript
WEB 前端开发中防治重复提交的实现方法
2016/10/26 Javascript
vue.js中Vue-router 2.0基础实践教程
2017/05/08 Javascript
Node.js使用Angular简单示例
2018/05/11 Javascript
JavaScript实现的级联算法示例【省市二级联动功能】
2018/12/25 Javascript
微信小程序视图控件与bindtap之间的问题的解决
2019/04/08 Javascript
Vue 用Vant实现时间选择器的示例代码
2019/10/25 Javascript
vue实现计步器功能
2019/11/01 Javascript
JS正则表达式验证密码强度
2020/03/18 Javascript
VSCode搭建React Native环境
2020/05/07 Javascript
python中正则表达式 re.findall 用法
2018/10/23 Python
python获取指定日期范围内的每一天,每个月,每季度的方法
2019/08/08 Python
django 自定义过滤器(filter)处理较为复杂的变量方法
2019/08/12 Python
pytorch实现特殊的Module--Sqeuential三种写法
2020/01/15 Python
python利用百度云接口实现车牌识别的示例
2020/02/21 Python
python爬虫开发之selenium模块详细使用方法与实例全解
2020/03/09 Python
Django查询优化及ajax编码格式原理解析
2020/03/25 Python
利用python3筛选excel中特定的行(行值满足某个条件/行值属于某个集合)
2020/09/04 Python
Linux常见面试题
2016/10/04 面试题
倡议书的格式写法
2015/04/28 职场文书
民事申诉状范本
2015/05/20 职场文书
党支部半年考察意见
2015/06/01 职场文书
nginx 反向代理之 proxy_pass的实现
2021/03/31 Servers
TensorFlow中tf.batch_matmul()的用法
2021/06/02 Python
JavaScript前端面试组合函数
2022/06/21 Javascript