keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取邮件地址的方法
Jul 10 Python
python搭建虚拟环境的步骤详解
Sep 27 Python
Jupyter notebook远程访问服务器的方法
May 24 Python
Python框架Flask的基本数据库操作方法分析
Jul 13 Python
pytorch索引查找 index_select的例子
Aug 18 Python
python带参数打包exe及调用方式
Dec 21 Python
完美解决pycharm导入自己写的py文件爆红问题
Feb 12 Python
django模型动态修改参数,增加 filter 字段的方式
Mar 16 Python
Python检测端口IP字符串是否合法
Jun 05 Python
sklearn线性逻辑回归和非线性逻辑回归的实现
Jun 09 Python
python机器学习创建基于规则聊天机器人过程示例详解
Nov 02 Python
分享Python异步爬取知乎热榜
Apr 12 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
PHP分页显示制作详细讲解
2008/11/19 PHP
Optimizer与Debugger兼容性问题的解决方法
2008/12/01 PHP
php发送与接收流文件的方法
2015/02/11 PHP
yii2-GridView在开发中常用的功能及技巧总结
2017/01/07 PHP
解决form中action属性后面?传递参数 获取不到的问题
2017/07/21 PHP
Laravel中使用Queue的最基本操作教程
2017/12/27 PHP
JavaScript浏览器选项卡效果
2010/08/25 Javascript
js substring从右边获取指定长度字符串(示例代码)
2013/12/23 Javascript
JQuery中dataGrid设置行的高度示例代码
2014/01/03 Javascript
JavaScript-RegExp对象只能使用一次问题解决方法
2014/06/23 Javascript
上传文件返回的json数据会被提示下载问题解决方案
2014/12/03 Javascript
jquery隔行换色效果实现方法
2015/01/15 Javascript
如何使用jQuery技术开发ios风格的页面导航菜单
2015/07/29 Javascript
CSS或者JS实现鼠标悬停显示另一元素
2016/01/22 Javascript
AngulerJS学习之按需动态加载文件
2017/02/13 Javascript
JS实现页面打印(整体、局部)
2017/08/18 Javascript
使用Vue组件实现一个简单弹窗效果
2018/04/23 Javascript
JavaScript中的回调函数实例讲解
2019/01/27 Javascript
ES6 如何改变JS内置行为的代理与反射
2019/02/11 Javascript
如何使用webpack打包一个库library的方法步骤
2019/12/18 Javascript
解决ant design vue 表格a-table二次封装,slots渲染的问题
2020/10/28 Javascript
python练习程序批量修改文件名
2014/01/16 Python
讲解Python中运算符使用时的优先级
2015/05/14 Python
python实现求最长回文子串长度
2018/01/22 Python
TensorFlow实现创建分类器
2018/02/06 Python
解决Pyinstaller 打包exe文件 取消dos窗口(黑框框)的问题
2019/06/21 Python
python 实现性别识别
2020/11/21 Python
Fashion Eyewear美国:英国线上设计师眼镜和太阳镜的零售商
2016/08/15 全球购物
抽象类和接口的区别
2012/09/19 面试题
GMP办公室主任岗位职责
2014/03/14 职场文书
工厂搬迁方案
2014/05/11 职场文书
小学生保护环境倡议书
2014/05/15 职场文书
代领学位证书毕业证书委托书
2014/09/30 职场文书
努力工作保证书
2015/02/28 职场文书
大学三好学生主要事迹范文
2015/11/03 职场文书
Python使用OpenCV和K-Means聚类对毕业照进行图像分割
2021/06/11 Python