keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python用sndhdr模块识别音频格式详解
Jan 11 Python
Python 函数返回值的示例代码
Mar 11 Python
python实现趣味图片字符化
Apr 30 Python
Python 使用PyQt5 完成选择文件或目录的对话框方法
Jun 27 Python
关于Python中的向量相加和numpy中的向量相加效率对比
Aug 26 Python
Django项目创建到启动详解(最全最详细)
Sep 07 Python
Python通过VGG16模型实现图像风格转换操作详解
Jan 16 Python
使用Python+selenium实现第一个自动化测试脚本
Mar 17 Python
使用opencv识别图像红色区域,并输出红色区域中心点坐标
Jun 02 Python
Django CBV模型源码运行流程详解
Aug 17 Python
python判断变量是否为列表的方法
Sep 17 Python
python性能测试工具locust的使用
Dec 28 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
日本十大最佳动漫,全都是二次元的神级作品
2019/10/05 日漫
php和js交互一例-PHP教程,PHP应用
2007/01/03 PHP
PHP URL参数获取方式的四种例子
2014/02/28 PHP
php封装的表单验证类完整实例
2016/10/19 PHP
php实现替换手机号中间数字为*号及隐藏IP最后几位的方法
2016/11/16 PHP
thinkphp ajaxfileupload实现异步上传图片的示例
2017/08/28 PHP
thinkphp5修改view到根目录实例方法
2019/07/02 PHP
laravel实现图片上传预览,及编辑时可更换图片,并实时变化的例子
2019/11/14 PHP
通过jQuery源码学习javascript(二)
2012/12/27 Javascript
node.js中的fs.writeFileSync方法使用说明
2014/12/14 Javascript
jQuery中get()方法用法实例
2014/12/27 Javascript
JQuery中Bind()事件用法分析
2015/05/05 Javascript
JS原型链怎么理解
2016/06/27 Javascript
html5 canvas 详细使用教程
2017/01/20 Javascript
vuejs父子组件之间数据交互详解
2017/08/09 Javascript
详解如何使用webpack在vue项目中写jsx语法
2017/11/08 Javascript
vue 项目接口管理的实现
2019/01/17 Javascript
在微信小程序中保存网络图片
2019/02/12 Javascript
python实现人人自动回复、抢沙发功能
2018/06/08 Python
Python 爬取携程所有机票的实例代码
2018/06/11 Python
Python Pillow Image Invert
2019/01/22 Python
Python3.4学习笔记之列表、数组操作示例
2019/03/01 Python
详解用python生成随机数的几种方法
2019/08/04 Python
python实现批量修改文件名
2020/03/23 Python
python使用正则表达式匹配txt特定字符串(有换行)
2020/12/09 Python
复古斯堪的纳维亚儿童服装:Baby go Retro
2017/09/09 全球购物
导游实习生自荐书
2014/01/28 职场文书
金融管理应届生求职信
2014/02/20 职场文书
地质灾害防治方案
2014/05/14 职场文书
校运动会广播稿(100篇)
2014/09/12 职场文书
小学母亲节活动总结
2015/02/10 职场文书
电气工程师岗位职责
2015/02/12 职场文书
警示教育片观后感
2015/06/17 职场文书
tensorboard 可视化之localhost:6006不显示的解决方案
2021/05/22 Python
浅谈Java实现分布式事务的三种方案
2021/06/11 Java/Android
深入讲解Vue中父子组件通信与事件触发
2022/03/22 Vue.js