keras 自定义loss层+接受输入实例


Posted in Python onJune 28, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)

首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

一、keras自定义损失函数

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一
def vae_loss(x, x_decoded_mean):
 xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
 kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
 return xent_loss + kl_loss
 
vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
 
 def __init__(self, **kwargs):
  self.is_placeholder = True
  super(CustomVariationalLayer, self).__init__(**kwargs)
 def vae_loss(self, x, x_decoded_mean_squash):
 
  x = K.flatten(x)
  x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
  xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
  kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
  return K.mean(xent_loss + kl_loss)
 
 def call(self, inputs):
 
  x = inputs[0]
  x_decoded_mean_squash = inputs[1]
  loss = self.vae_loss(x, x_decoded_mean_squash)
  self.add_loss(loss, inputs=inputs)
  # We don't use this output.
  return x
 
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

二、keras中的样本权重

# Import
import numpy as np
from sklearn.utils import class_weight
 
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
 
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'])
 
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
           np.unique(y_train),
           y_train)
 
# Add the class weights to the training           
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)

Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].

以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python求列表交集的方法汇总
Nov 10 Python
python3音乐播放器简单实现代码
Apr 20 Python
深入理解Python中的*重复运算符
Oct 28 Python
Python实现螺旋矩阵的填充算法示例
Dec 28 Python
在CentOS6上安装Python2.7的解决方法
Jan 09 Python
将字典转换为DataFrame并进行频次统计的方法
Apr 08 Python
Python中GIL的使用详解
Oct 03 Python
对python读取zip压缩文件里面的csv数据实例详解
Feb 08 Python
python实现批量处理将图片粘贴到另一张图片上并保存
Dec 12 Python
Python内存泄漏和内存溢出的解决方案
Sep 26 Python
python利用platform模块获取系统信息
Oct 09 Python
Python的Tqdm模块实现进度条配置
Feb 24 Python
python批量处理多DNS多域名的nslookup解析实现
Jun 28 #Python
解决Keras TensorFlow 混编中 trainable=False设置无效问题
Jun 28 #Python
Keras: model实现固定部分layer,训练部分layer操作
Jun 28 #Python
sklearn的predict_proba使用说明
Jun 28 #Python
基于python实现ROC曲线绘制广场解析
Jun 28 #Python
Python sklearn中的.fit与.predict的用法说明
Jun 28 #Python
浅谈sklearn中predict与predict_proba区别
Jun 28 #Python
You might like
php xml实例 留言本
2009/03/20 PHP
php实现websocket实时消息推送
2018/03/30 PHP
PHP编程一定要改掉的5个不良习惯
2020/09/18 PHP
一个JS小玩意 几个属性相加不能超过一个特定值.
2009/09/29 Javascript
基于JQuery实现的Select级联
2014/01/27 Javascript
Jquery倒计时源码分享
2014/05/16 Javascript
jQuery实现textarea自动增长宽高的方法
2015/12/18 Javascript
探讨:JavaScript ECAMScript5 新特性之get/set访问器
2016/05/05 Javascript
基于JQuery实现分隔条的功能
2016/06/17 Javascript
jQuery 获取遍历获取table中每一个tr中的第一个td的方法
2016/10/05 Javascript
关于Jquery中的bind(),on()绑定事件方式总结
2016/10/26 Javascript
canvas实现简易的圆环进度条效果
2017/02/28 Javascript
bootstrap table表格使用方法详解
2017/04/26 Javascript
VUE 使用中踩过的坑
2018/02/08 Javascript
layer弹出层 iframe层去掉滚动条的实例代码
2018/08/17 Javascript
Vue.set 全局操作简单示例
2019/09/19 Javascript
js实现特别简单的钟表效果
2020/09/14 Javascript
Python3实现的腾讯微博自动发帖小工具
2013/11/11 Python
Python中在脚本中引用其他文件函数的实现方法
2016/06/23 Python
深入浅析ImageMagick命令执行漏洞
2016/10/11 Python
python与sqlite3实现解密chrome cookie实例代码
2018/01/20 Python
python生成圆形图片的方法
2020/03/25 Python
PyQt5每天必学之关闭窗口
2018/04/19 Python
python 每天如何定时启动爬虫任务(实现方法分享)
2018/05/21 Python
python实现五子棋人机对战游戏
2020/03/25 Python
Tensorflow实现神经网络拟合线性回归
2019/07/19 Python
Python Multiprocessing多进程 使用tqdm显示进度条的实现
2019/08/13 Python
Tensorflow累加的实现案例
2020/02/05 Python
CSS3中background-clip和background-origin的区别示例介绍
2014/03/10 HTML / CSS
html5 迷宫游戏(碰撞检测)实例一
2013/07/25 HTML / CSS
伦敦高级内衣品牌:Agent Provocateur(大内密探)
2016/08/23 全球购物
宣传普通话标语
2014/06/27 职场文书
认错检讨书
2014/10/02 职场文书
个人党性锻炼总结
2015/03/05 职场文书
详解JS WebSocket断开原因和心跳机制
2021/05/07 Javascript
试用1103暨1103、1101同门大比武 [ DAIWEI ]
2022/04/05 无线电