keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基础入门之seed()方法的使用
May 15 Python
详解Django缓存处理中Vary头部的使用
Jul 24 Python
Python基础教程之浅拷贝和深拷贝实例详解
Jul 15 Python
TensorFLow用Saver保存和恢复变量
Mar 10 Python
python微信公众号开发简单流程
Mar 23 Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 Python
python实现贪吃蛇小游戏
Mar 21 Python
Python写一个基于MD5的文件监听程序
Mar 11 Python
python的set处理二维数组转一维数组的方法示例
May 31 Python
python使用Pandas库提升项目的运行速度过程详解
Jul 12 Python
pytorch点乘与叉乘示例讲解
Dec 27 Python
基于python 凸包问题的解决
Apr 16 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
JAVA/JSP学习系列之六
2006/10/09 PHP
在IIS上安装PHP4.0正式版
2006/10/09 PHP
PHP+Ajax实现的无刷新分页功能详解【附demo源码下载】
2017/07/03 PHP
javascript document.compatMode兼容性
2010/02/23 Javascript
JavaScript 学习初步 入门教程
2010/03/25 Javascript
JsDom 编程小结
2011/08/09 Javascript
Jquery读取URL参数小例子
2013/08/30 Javascript
鼠标拖动实现DIV排序示例代码
2013/10/14 Javascript
jquery text(),val(),html()方法区别总结
2013/11/04 Javascript
JS短路原理的应用示例 精简代码的途径
2013/12/13 Javascript
jQuery中next()方法用法实例
2015/01/07 Javascript
Javascript控制input输入时间格式的方法
2015/01/28 Javascript
Node.js中的缓冲与流模块详细介绍
2015/02/11 Javascript
json定义及jquery操作json的方法
2016/09/29 Javascript
JS实现选定指定HTML元素对象中指定文本内容功能示例
2017/02/13 Javascript
JavaScript实现的超简单计算器功能示例
2017/12/23 Javascript
如何从零开始手写Koa2框架
2019/03/22 Javascript
JS多个异步请求 按顺序执行next实现解析
2019/09/16 Javascript
JavaScript隐式类型转换代码实例
2020/05/29 Javascript
[02:41]2015国际邀请赛中国区预选赛观战指南
2015/05/20 DOTA
Python利用matplotlib生成图片背景及图例透明的效果
2017/04/27 Python
python+django加载静态网页模板解析
2017/12/12 Python
Python建立Map写Excel表实例解析
2018/01/17 Python
Python pandas常用函数详解
2018/02/07 Python
python修改txt文件中的某一项方法
2018/12/29 Python
详解Python的循环结构知识点
2019/05/20 Python
Python多进程编程multiprocessing代码实例
2020/03/12 Python
Jabra捷波朗美国官网:用于办公、车载和运动的无线蓝牙耳麦
2017/02/01 全球购物
科级干部群众路线教育实践活动对照检查材料思想汇报
2014/09/20 职场文书
小学教师工作总结2015
2015/04/07 职场文书
大学生见习总结报告
2015/06/24 职场文书
新闻稿格式范文
2015/07/18 职场文书
治理商业贿赂工作总结
2015/08/10 职场文书
2019年学校消防安全责任书(2篇)
2019/10/09 职场文书
mybatis3中@SelectProvider传递参数方式
2021/08/04 Java/Android
vue实现列表垂直无缝滚动
2022/04/08 Vue.js