Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python计算程序运行时间的方法
Dec 13 Python
对pandas的dataframe绘图并保存的实现方法
Aug 05 Python
caffe binaryproto 与 npy相互转换的实例讲解
Jul 09 Python
Python之使用adb shell命令启动应用的方法详解
Jan 07 Python
python中bs4.BeautifulSoup的基本用法
Jul 27 Python
Python CVXOPT模块安装及使用解析
Aug 01 Python
Python字典推导式将cookie字符串转化为字典解析
Aug 10 Python
pygame实现俄罗斯方块游戏(基础篇2)
Oct 29 Python
如何使用Python发送HTML格式的邮件
Feb 11 Python
Python unittest工作原理和使用过程解析
Feb 24 Python
python GUI库图形界面开发之PyQt5浏览器控件QWebEngineView详细使用方法
Feb 26 Python
如何使用 Flask 做一个评论系统
Nov 27 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php入门教程 精简版
2009/12/13 PHP
解析php addslashes()与addclashes()函数的区别和比较
2013/06/24 PHP
Yii2中hasOne、hasMany及多对多关联查询的用法详解
2017/02/15 PHP
一些常用弹出窗口/拖放/异步文件上传等实用代码
2013/01/06 Javascript
js中精确计算加法和减法示例
2014/03/28 Javascript
jquery 操作css样式、位置、尺寸方法汇总
2014/11/28 Javascript
Javascript动画的实现原理浅析
2015/03/02 Javascript
JS简单实现城市二级联动选择插件的方法
2015/08/19 Javascript
Bootstrap布局组件应用实例讲解
2016/02/17 Javascript
js实现人民币大写金额形式转换
2016/04/27 Javascript
Angularjs中的页面访问权限怎么设置
2016/11/11 Javascript
AngularJS实现DOM元素的显示与隐藏功能
2016/11/22 Javascript
浅谈使用React.setState需要注意的三点
2017/12/18 Javascript
python中cPickle用法例子分享
2014/01/03 Python
python使用beautifulsoup从爱奇艺网抓取视频播放
2014/01/23 Python
基于python的字节编译详解
2017/09/20 Python
Python实现计算圆周率π的值到任意位的方法示例
2018/05/08 Python
python简单区块链模拟详解
2019/07/03 Python
Python线程障碍对象Barrier原理详解
2019/12/02 Python
kafka-python 获取topic lag值方式
2019/12/23 Python
Python调用百度OCR实现图片文字识别的示例代码
2020/07/17 Python
在PyCharm中安装PaddlePaddle的方法
2021/02/05 Python
HTML5 canvas 瀑布流文字效果的示例代码
2018/01/31 HTML / CSS
苹果美国官方商城:Apple美国
2016/08/24 全球购物
Toppik顶丰增发纤维官网:解决头发稀疏
2017/12/30 全球购物
Vero Moda西班牙官方购物网站:丹麦BESTSELLER旗下知名女装品牌
2018/04/27 全球购物
《听鱼说话》教学反思
2014/02/15 职场文书
大学应届生的自我评价
2014/03/06 职场文书
活动总结怎么写
2014/04/28 职场文书
测绘工程专业求职信
2014/07/15 职场文书
2014最新毕业证代领委托书
2014/09/26 职场文书
2014年向国旗敬礼活动总结
2014/09/27 职场文书
职称评定个人总结
2015/03/05 职场文书
培训学校2015年度工作总结
2015/07/20 职场文书
导游词之嵊泗列岛
2019/10/30 职场文书
Python天气语音播报小助手
2021/09/25 Python