keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
介绍Python中几个常用的类方法
Apr 08 Python
Python实现周期性抓取网页内容的方法
Nov 04 Python
python实现自主查询实时天气
Jun 22 Python
Python实现的连接mssql数据库操作示例
Aug 17 Python
在Pycharm中对代码进行注释和缩进的方法详解
Jan 20 Python
python安装scipy的方法步骤
Jun 26 Python
对python 树状嵌套结构的实现思路详解
Aug 09 Python
Python爬取腾讯视频评论的思路详解
Dec 19 Python
Python装饰器用法与知识点小结
Mar 09 Python
解决keras GAN训练是loss不发生变化,accuracy一直为0.5的问题
Jul 02 Python
Python3.8安装Pygame教程步骤详解
Aug 14 Python
Python的三个重要函数详解
Jan 18 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
WordPress中获取所使用的模板的页面ID的简单方法
2015/12/31 PHP
JavaScript中的私有/静态属性介绍
2012/07/26 Javascript
Flexigrid在IE下不显示数据的有效处理方法
2014/09/04 Javascript
JS实现跟随鼠标闪烁转动色块的方法
2015/02/26 Javascript
jQuery实现根据类型自动显示和隐藏表单
2015/03/18 Javascript
浅谈setTimeout 与 setInterval
2015/06/23 Javascript
学习JavaScript设计模式(多态)
2015/11/25 Javascript
Javascript之面向对象--封装
2016/12/02 Javascript
JS使用面向对象技术实现的tab选项卡效果示例
2017/02/28 Javascript
jquery实现放大镜简洁代码(推荐)
2017/06/08 jQuery
Bootstrap模态框插入视频的实现代码
2017/06/25 Javascript
微信小程序分页加载的实例代码
2017/07/11 Javascript
vue-cli开发时,关于ajax跨域的解决方法(推荐)
2018/02/03 Javascript
Vue实现搜索 和新闻列表功能简单范例
2018/03/16 Javascript
vue通过style或者class改变样式的实例代码
2018/10/30 Javascript
vue中使用rem布局代码详解
2019/10/30 Javascript
js如何验证密码强度
2020/03/18 Javascript
纯js+css实现在线时钟
2020/08/18 Javascript
[48:21]林俊杰圣堂刺客超神杀戮秀
2014/10/29 DOTA
python使用os模块的os.walk遍历文件夹示例
2014/01/27 Python
如何在sae中设置django,让sae的工作环境跟本地python环境一致
2017/11/21 Python
对python for 文件指定行读写操作详解
2018/12/29 Python
Flask配置Cors跨域的实现
2019/07/12 Python
Python实现性能自动化测试竟然如此简单
2019/07/30 Python
pymysql模块的操作实例
2019/12/17 Python
家庭户外服装:Hawkshead
2017/11/02 全球购物
导游的职业规划书范文
2013/12/27 职场文书
偷看我的初中毕业鉴定
2014/01/29 职场文书
幼儿园大班教学反思
2014/02/10 职场文书
应届大专毕业生自我鉴定
2014/04/08 职场文书
2014大学生中国梦主题教育学习思想汇报
2014/09/10 职场文书
夫妻婚内购房协议书
2014/10/05 职场文书
群众路线自查自纠工作情况报告
2014/10/28 职场文书
2015年检验员工作总结范文
2015/04/30 职场文书
小学生大队委竞选稿
2015/11/20 职场文书
Golang 正则匹配效率详解
2021/04/25 Golang