keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python类和函数中使用静态变量的方法
May 09 Python
详解用python实现简单的遗传算法
Jan 02 Python
Python实用技巧之利用元组代替字典并为元组元素命名
Jul 11 Python
设置python3为默认python的方法
Oct 31 Python
Python读取Pickle文件信息并计算与当前时间间隔的方法分析
Jan 30 Python
Python 依赖库太多了该如何管理
Nov 08 Python
详解Python在使用JSON时需要注意的编码问题
Dec 06 Python
Python猴子补丁知识点总结
Jan 05 Python
django model object序列化实例
Mar 13 Python
python Django 反向访问器的外键冲突解决
May 20 Python
Python 列表反转显示的四种方法
Nov 16 Python
Python Http请求json解析库用法解析
Nov 28 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
PHP脚本数据库功能详解(上)
2006/10/09 PHP
PHP中防止SQL注入实现代码
2011/02/19 PHP
php使用mkdir创建多级目录入门例子
2014/05/10 PHP
合并ThinkPHP配置文件以消除代码冗余的实现方法
2014/07/22 PHP
PHP函数eval()介绍和使用示例
2014/08/20 PHP
php压缩文件夹最新版
2018/07/18 PHP
PHP PDOStatement::rowCount讲解
2019/02/01 PHP
PHP获取ttf格式文件字体名的方法示例
2019/03/06 PHP
JavaScript随机排序(随即出牌)
2010/09/17 Javascript
学习面向对象之面向对象的基本概念:对象和其他基本要素
2010/11/30 Javascript
JS事件在IE与FF中的区别详细解析
2013/11/20 Javascript
原生javascript实现的ajax异步封装功能示例
2016/11/03 Javascript
JavaScript实现二维坐标点排序效果
2017/07/18 Javascript
element-ui 限制日期选择的方法(datepicker)
2018/05/16 Javascript
Vue 获取数组键名的方法
2018/06/21 Javascript
vue@cli3项目模板怎么使用public目录下的静态文件
2020/07/07 Javascript
Python基于socket模块实现UDP通信功能示例
2018/04/10 Python
python:接口间数据传递与调用方法
2018/12/17 Python
详解如何用python实现一个简单下载器的服务端和客户端
2019/10/28 Python
浅析Python数字类型和字符串类型的内置方法
2019/12/22 Python
Python中常用的高阶函数实例详解
2020/02/21 Python
如何在Python对Excel进行读取
2020/06/04 Python
纯HTML5+CSS3制作图片旋转
2016/01/12 HTML / CSS
Jo Malone美国官网:祖玛珑香水
2017/03/27 全球购物
优秀毕业生求职信范文
2014/01/02 职场文书
保安的辞职报告怎么写
2014/01/20 职场文书
护士辞职信模板
2014/01/20 职场文书
医院检讨书范文
2014/02/01 职场文书
入党申请自荐书范文
2014/02/11 职场文书
yy司仪主持词
2014/03/22 职场文书
授权委托书样本
2014/04/03 职场文书
公证书样本
2014/04/10 职场文书
海洋科学专业求职信
2014/08/10 职场文书
2014年社区计生工作总结
2014/11/18 职场文书
运动会致辞稿
2015/07/29 职场文书
Mysql表数据比较大情况下修改添加字段的方法实例
2022/06/28 MySQL