keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 从远程服务器下载东西的代码
Feb 10 Python
使用Python的web.py框架实现类似Django的ORM查询的教程
May 02 Python
在Python中操作文件之seek()方法的使用教程
May 24 Python
Python开发如何在ubuntu 15.10 上配置vim
Jan 25 Python
Python cookbook(数据结构与算法)找出序列中出现次数最多的元素算法示例
Mar 15 Python
python 自动去除空行的实例
Jul 24 Python
python学习--使用QQ邮箱发送邮件代码实例
Apr 16 Python
python日志模块logbook使用方法
Sep 19 Python
使用matlab或python将txt文件转为excel表格
Nov 01 Python
python如何判断IP地址合法性
Apr 05 Python
pyCharm 实现关闭代码检查
Jun 09 Python
Python2手动安装更新pip过程实例解析
Jul 16 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
PHP 多维数组排序(usort,uasort)
2010/06/30 PHP
PHP Class&Object -- 解析PHP实现二叉树
2013/06/25 PHP
PHP实现的简单缓存类
2015/07/29 PHP
php设计模式之享元模式分析【星际争霸游戏案例】
2020/03/23 PHP
js类型检查实现代码
2010/10/29 Javascript
Javascript 构造函数详解
2014/10/22 Javascript
使用jQuery给input标签设置默认值
2016/06/20 Javascript
Vue-Router模式和钩子的用法
2018/02/28 Javascript
Vue 使用中的小技巧
2018/04/26 Javascript
Bootstrap table中toolbar新增条件查询及refresh参数使用方法
2018/05/18 Javascript
讲解vue-router之命名路由和命名视图
2018/05/28 Javascript
Vue 中对图片地址进行拼接的方法
2018/09/03 Javascript
关于js陀螺仪的理解分析
2019/04/11 Javascript
json解析大全 双引号、键值对不在一起的情况
2019/12/06 Javascript
vuex 多模块时 模块内部的mutation和action的调用方式
2020/07/24 Javascript
vue 二维码长按保存和复制内容操作
2020/09/22 Javascript
Nuxt.js nuxt-link与router-link的区别说明
2020/11/06 Javascript
零基础写python爬虫之HTTP异常处理
2014/11/05 Python
Python获取文件所在目录和文件名的方法
2017/01/12 Python
详解python中的线程
2018/02/10 Python
Python运维之获取系统CPU信息的实现方法
2018/06/11 Python
在Mac上删除自己安装的Python方法
2018/10/29 Python
Python Web版语音合成实例详解
2019/07/16 Python
使用Python给头像戴上圣诞帽的图像操作过程解析
2019/09/20 Python
Python操作word文档插入图片和表格的实例演示
2020/10/25 Python
Html5 Canvas 实现一个“刮刮乐”游戏
2019/09/05 HTML / CSS
HTML5页面直接调用百度地图API获取当前位置直接导航目的地的实现代码
2018/03/02 HTML / CSS
小学毕业演讲稿
2014/04/25 职场文书
活动总结格式
2014/08/30 职场文书
2014院党委领导班子对照检查材料思想汇报
2014/09/24 职场文书
三八节活动简报
2015/07/20 职场文书
汶川大地震感悟
2015/08/10 职场文书
PyTorch 如何设置随机数种子使结果可复现
2021/05/12 Python
python入门学习关于for else的特殊特性讲解
2021/11/20 Python
图片批量处理 - 尺寸、格式、水印等
2022/03/07 杂记
Windows server 2012搭建FTP服务器
2022/04/29 Servers