keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
介绍Python的Django框架中的静态资源管理器django-pipeline
Apr 25 Python
在Python中操作字典之setdefault()方法的使用
May 21 Python
深入理解python中的浅拷贝和深拷贝
May 30 Python
python使用pandas实现数据分割实例代码
Jan 25 Python
python numpy 显示图像阵列的实例
Jul 02 Python
对tf.reduce_sum tensorflow维度上的操作详解
Jul 26 Python
Python 3.3实现计算两个日期间隔秒数/天数的方法示例
Jan 07 Python
Python根据欧拉角求旋转矩阵的实例
Jan 28 Python
Python3+OpenCV2实现图像的几何变换(平移、镜像、缩放、旋转、仿射)
May 13 Python
Python内置方法实现字符串的秘钥加解密(推荐)
Dec 09 Python
浅谈Python访问MySQL的正确姿势
Jan 07 Python
解决Python3.8用pip安装turtle-0.0.2出现错误问题
Feb 11 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
PHP网站备份程序代码分享
2011/06/10 PHP
php数据结构 算法(PHP描述) 简单选择排序 simple selection sort
2011/08/09 PHP
小议Function.apply()之二------利用Apply的参数数组化来提高 JavaScript程序性能
2006/11/30 Javascript
JavaScript 中的事件教程
2007/04/05 Javascript
Javascript+CSS实现影像卷帘效果思路及代码
2014/10/20 Javascript
高性能JavaScript模板引擎实现原理详解
2015/02/05 Javascript
jQuery插件slides实现无缝轮播图特效
2015/04/17 Javascript
使用基于Node.js的构建工具Grunt来发布ASP.NET MVC项目
2016/02/15 Javascript
BootStrap的弹出框(Popover)支持鼠标移到弹出层上弹窗层不隐藏的原因及解决办法
2016/04/03 Javascript
jQuery输入框密码的显示隐藏【代码分享】
2017/04/29 jQuery
vue实现样式之间的切换及vue动态样式的实现方法
2017/12/19 Javascript
node打造微信个人号机器人的方法示例
2018/04/26 Javascript
jQuery实现侧边栏隐藏与显示的方法详解
2018/12/22 jQuery
原生js基于canvas实现一个简单的前端截图工具代码实例
2019/09/10 Javascript
详解vuejs中执行npm run dev出现页面cannot GET/问题
2020/04/26 Javascript
vue移动端下拉刷新和上滑加载
2020/10/27 Javascript
JS前端基于canvas给图片添加水印
2020/11/11 Javascript
[51:53]完美世界DOTA2联赛循环赛 LBZS vs DM BO2第二场 11.01
2020/11/02 DOTA
Python实现115网盘自动下载的方法
2014/09/30 Python
python文件写入实例分析
2015/04/08 Python
在Python中marshal对象序列化的相关知识
2015/07/01 Python
python下调用pytesseract识别某网站验证码的实现方法
2016/06/06 Python
Python+matplotlib+numpy实现在不同平面的二维条形图
2018/01/02 Python
详细解读tornado协程(coroutine)原理
2018/01/15 Python
python调用API实现智能回复机器人
2018/04/10 Python
详解Python中的四种队列
2018/05/21 Python
Django使用Mysql数据库已经存在的数据表方法
2018/05/27 Python
python下载卫星云图合成gif的方法示例
2020/02/18 Python
在pycharm创建scrapy项目的实现步骤
2020/12/01 Python
瑞士隐形眼镜和护理产品网上商店:Linsenklick
2019/10/21 全球购物
德国的大型美妆个护电商:Flaconi
2020/06/26 全球购物
建筑文秘专业个人求职信范文
2013/12/28 职场文书
车间核算员岗位职责
2014/07/01 职场文书
具结保证书范本
2015/05/11 职场文书
地道战观后感2000字
2015/06/04 职场文书
2016党员学习《反对自由主义》心得体会
2016/01/22 职场文书