keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取远程文件大小的函数代码分享
May 13 Python
Django文件存储 默认存储系统解析
Aug 02 Python
Python paramiko模块使用解析(实现ssh)
Aug 30 Python
Python2比较当前图片跟图库哪个图片相似的方法示例
Sep 28 Python
Pycharm 2019 破解激活方法图文详解
Oct 11 Python
如何基于python实现归一化处理
Jan 20 Python
Django单元测试中Fixtures的使用方法
Feb 26 Python
pyinstaller打包找不到文件的问题解决
Apr 15 Python
解决paramiko执行命令超时的问题
Apr 16 Python
Python xlrd/xlwt 创建excel文件及常用操作
Sep 24 Python
Python爬虫之爬取某文库文档数据
Apr 21 Python
利用Python实现模拟登录知乎
May 25 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
初学者入门:细述PHP4的核心Zend
2006/09/05 PHP
PHP 数组排序方法总结 推荐收藏
2010/06/30 PHP
php输出表格的实现代码(修正版)
2010/12/29 PHP
php实现mysql封装类示例
2014/05/07 PHP
PHP数组函数知识汇总
2016/05/12 PHP
PHP中的use关键字及文件的加载详解
2016/11/28 PHP
php+jQuery实现的三级导航栏下拉菜单显示效果
2017/08/10 PHP
PHP自动生成缩略图函数的源码示例
2019/03/18 PHP
基于Jquery的仿Windows Aero弹出窗(漂亮的关闭按钮)
2010/09/28 Javascript
Jquery 跨域访问 Lightswitch OData Service的方法
2013/09/11 Javascript
jquery.form.js实现将form提交转为ajax方式提交的方法
2015/04/07 Javascript
jQuery实现的网页右下角tab样式在线客服效果代码
2015/10/23 Javascript
Jquery 效果使用详解
2015/11/23 Javascript
JavaScript 轮播图和自定义滚动条配合鼠标滚轮分享代码贴
2016/10/28 Javascript
JavaScript事件方法(实例讲解)
2017/06/27 Javascript
vue封装第三方插件并发布到npm的方法
2017/09/25 Javascript
angularJS实现动态添加,删除div方法
2018/02/27 Javascript
vue devtools的安装与使用教程
2018/08/08 Javascript
浅谈vue限制文本框输入数字的正确姿势
2019/09/02 Javascript
高效jQuery选择器的5个技巧实例分析
2019/11/26 jQuery
python操作MySQL数据库具体方法
2013/10/28 Python
用Python编程实现语音控制电脑
2014/04/01 Python
Python PyQt5标准对话框用法示例
2017/08/23 Python
Python单链表原理与实现方法详解
2020/02/22 Python
html5+CSS3+JS实现七夕言情功能代码
2017/08/28 HTML / CSS
Web时代变迁及html5与html4的区别
2016/01/06 HTML / CSS
公务员年总结的自我评价
2013/10/25 职场文书
生产部主管岗位职责
2014/01/06 职场文书
社区学习雷锋活动总结
2014/04/25 职场文书
销售个人求职信范文
2014/04/28 职场文书
医学生求职信
2014/07/01 职场文书
党的群众路线教育实践活动个人对照检查材料(校长)
2014/11/05 职场文书
2014年招商引资工作总结
2014/11/22 职场文书
不尊敬老师的检讨书
2014/12/21 职场文书
2019年教师节:送给所有老师的祝福语
2019/09/05 职场文书
Python实现视频自动打码的示例代码
2022/04/08 Python