keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
零基础写python爬虫之使用Scrapy框架编写爬虫
Nov 07 Python
Python实现数通设备端口使用情况监控实例
Jul 15 Python
Python实现基本数据结构中栈的操作示例
Dec 04 Python
python更改已存在excel文件的方法
May 03 Python
Python批量发送post请求的实现代码
May 05 Python
浅析Python函数式编程
Oct 06 Python
使用PyQt4 设置TextEdit背景的方法
Jun 14 Python
Pycharm 安装 idea VIM插件的图文教程详解
Feb 21 Python
python使用turtle库绘制奥运五环
Feb 24 Python
树莓派升级python的具体步骤
Jul 05 Python
Python加载数据的5种不同方式(收藏)
Nov 13 Python
Python图像处理库PIL详细使用说明
Apr 06 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
PHP5中使用PDO连接数据库的方法
2010/08/01 PHP
PHP文件上传判断file是否己选择上传文件的方法
2014/11/10 PHP
php模拟post提交数据的方法
2015/02/12 PHP
Yii2单元测试用法示例
2016/11/12 PHP
Thinkphp 框架基础之源码获取、环境要求与目录结构分析
2020/04/27 PHP
编辑浪子版表单验证类
2007/05/12 Javascript
把Node.js程序加入服务实现随机启动
2015/06/25 Javascript
JavaScript代码实现禁止右键、禁选择、禁粘贴、禁shift、禁ctrl、禁alt
2015/11/17 Javascript
一道常被人轻视的web前端常见面试题(JS)
2016/02/15 Javascript
JavaScript的new date等日期函数在safari中遇到的坑
2016/10/24 Javascript
为JQuery EasyUI 表单组件增加焦点切换功能的方法
2017/04/13 jQuery
js前端传json后台接收‘‘被转为quot的问题解决
2020/11/12 Javascript
通过vue.extend实现消息提示弹框的方法记录
2021/01/07 Vue.js
Javascript生成器(Generator)的介绍与使用
2021/01/31 Javascript
[00:34]TI7不朽珍藏III——地穴编织者不朽展示
2017/07/15 DOTA
使用pyecharts无法import Bar的解决方案
2020/04/23 Python
python flask 多对多表查询功能
2017/06/25 Python
python3编写ThinkPHP命令执行Getshell的方法
2019/02/26 Python
python使用pandas处理excel文件转为csv文件的方法示例
2019/07/18 Python
将pytorch转成longtensor的简单方法
2020/02/18 Python
基于Django signals 信号作用及用法详解
2020/03/28 Python
python用TensorFlow做图像识别的实现
2020/04/21 Python
Python 代码调试技巧示例代码
2020/08/11 Python
python实现录音功能(可随时停止录音)
2020/10/26 Python
CSS3移动端vw+rem不依赖JS实现响应式布局的方法
2019/01/23 HTML / CSS
HTML5注册页面示例代码
2014/03/27 HTML / CSS
计算机专业毕业生推荐信
2013/11/25 职场文书
20年同学聚会感言
2014/02/03 职场文书
幼儿园优秀教师事迹
2014/02/13 职场文书
怎么写好自荐书
2014/03/02 职场文书
大学生求职信范文
2014/05/24 职场文书
金融专业求职信
2014/08/05 职场文书
学校关爱留守儿童活动方案
2014/08/27 职场文书
客房部经理岗位职责
2015/02/02 职场文书
2016年领导干部廉政承诺书
2016/03/24 职场文书
golang 生成对应的数据表struct定义操作
2021/04/28 Golang