Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Pycharm编辑器技巧之自动导入模块详解
Jul 18 Python
利用python求解物理学中的双弹簧质能系统详解
Sep 29 Python
Python实现的KMeans聚类算法实例分析
Dec 29 Python
Python3实现的判断回文链表算法示例
Mar 08 Python
python使用BeautifulSoup与正则表达式爬取时光网不同地区top100电影并对比
Apr 15 Python
python实现批量转换图片为黑白
Jun 16 Python
基于Python 的语音重采样函数解析
Jul 06 Python
Python 解析简单的XML数据
Jul 24 Python
如何基于pandas读取csv后合并两个股票
Sep 25 Python
python 使用tkinter+you-get实现视频下载器
Nov 17 Python
详解pandas中利用DataFrame对象的.loc[]、.iloc[]方法抽取数据
Dec 13 Python
Python+腾讯云服务器实现每日自动健康打卡
Dec 06 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php通过隐藏表单控件获取到前两个页面的url
2014/09/09 PHP
Yii数据库缓存实例分析
2016/03/29 PHP
ThinkPHP框架表单验证操作方法
2017/07/19 PHP
Laravel 实现添加多语言提示信息
2019/10/25 PHP
PHP中用Trait封装单例模式的实现
2019/12/18 PHP
laravel框架学习笔记之组件化开发实现方法
2020/02/01 PHP
javascript 学习笔记(六)浏览器类型及版本信息检测代码
2011/04/08 Javascript
40个新鲜出炉的jQuery 插件和免费教程[上]
2012/07/24 Javascript
上传的js验证(图片/文件的扩展名)
2013/04/25 Javascript
Javascript Web Slider 焦点图示例源码
2013/10/10 Javascript
浮动的div自适应居中显示的js代码
2013/12/23 Javascript
javascript正则表达式基础知识入门
2015/04/20 Javascript
js实现左侧网页tab滑动门效果代码
2015/09/06 Javascript
js调用百度地图及调用百度地图的搜索功能
2015/09/07 Javascript
Bootstrap布局之栅格系统详解
2016/06/13 Javascript
微信小程序  modal弹框组件详解
2016/10/27 Javascript
利用js+css+html实现固定table的列头不动
2016/12/08 Javascript
es6 super关键字的理解与应用实例分析
2020/02/15 Javascript
[01:20]PWL S2开团时刻第三期——团战可以输 蝙蝠必须死
2020/11/26 DOTA
python查看zip包中文件及大小的方法
2015/07/09 Python
Python压缩解压缩zip文件及破解zip文件密码的方法
2015/11/04 Python
浅谈python和C语言混编的几种方式(推荐)
2017/09/27 Python
matplotlib设置legend图例代码示例
2017/12/19 Python
Django框架实现逆向解析url的方法
2018/07/04 Python
python Jupyter运行时间实例过程解析
2019/12/13 Python
Python log模块logging记录打印用法解析
2020/01/20 Python
详解python内置常用高阶函数(列出了5个常用的)
2020/02/21 Python
Python3如何使用tabulate打印数据
2020/09/25 Python
支持IE8的纯css3开发的响应式设计动画菜单教程
2014/11/05 HTML / CSS
CSS3+js实现简单的时钟特效
2015/03/18 HTML / CSS
EM Cosmetics官网:由彩妆大神Michelle Phan创办的独立品牌
2020/04/27 全球购物
畜牧兽医本科生的自我评价
2014/03/03 职场文书
园林设计专业毕业生求职信
2014/03/23 职场文书
公务员年度考核评语
2014/12/31 职场文书
演讲稿:态度决定一切
2019/04/02 职场文书
vue动态绑定style样式
2022/04/20 Vue.js