keras 自定义loss model.add_loss的使用详解


Posted in Python onJune 22, 2020

一点见解,不断学习,欢迎指正

1、自定义loss层作为网络一层加进model,同时该loss的输出作为网络优化的目标函数

from keras.models import Model
import keras.layers as KL
import keras.backend as K
import numpy as np
from keras.utils.vis_utils import plot_model
 
x_train=np.random.normal(1,1,(100,784))
 
x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
def custom_loss1(y_true,y_pred):
 return K.mean(K.abs(y_true-y_pred))
loss1=KL.Lambda(lambda x:custom_loss1(*x),name='loss1')([x,x_in])
 
model = Model(x_in, [loss1])
model.get_layer('loss1').output#取出loss
model.add_loss(loss1)#作为网络优化的目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
#
model.fit(x_train, None, epochs=5)

2、自定义loss,作为网络优化的目标函数

x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
 
model = Model(x_in, x)
loss = K.mean((x - x_in)**2)
model.add_loss(loss)#只是作为loss优化目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
model.fit(x_train, None, epochs=5)

补充知识:keras load_weights fine-tune

分享一个小技巧,就是在构建网络模型的时候,不要怕麻烦,给每一层都定义一个名字,这样在复用之前的参数权重的时候,除了官网给的先加载权重,再冻结权重之外,你可以通过简单的修改层的名字来达到加载之前训练的权重的目的,假设权重文件保存为model_pretrain.h5 ,重新使用的时候,我把想要复用的层的名字设置成一样的,然后

model.load_weights('model_pretrain.h5', by_name=True)

以上这篇keras 自定义loss model.add_loss的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 运算符 供重载参考
Jun 11 Python
python实现网站的模拟登录
Jan 04 Python
利用Python中SocketServer 实现客户端与服务器间非阻塞通信
Dec 15 Python
Python标准库sched模块使用指南
Jul 06 Python
pandas修改DataFrame列名的方法
Apr 08 Python
opencv python 图像去噪的实现方法
Aug 31 Python
python中关于数据类型的学习笔记
Jul 19 Python
哪种Python框架适合你?简单介绍几种主流Python框架
Aug 04 Python
如何在windows下安装配置python工具Ulipad
Oct 27 Python
如何用python识别滑块验证码中的缺口
Apr 01 Python
如何用python清洗文件中的数据
Jun 18 Python
Django中celery的使用项目实例
Jul 07 Python
Python项目跨域问题解决方案
Jun 22 #Python
python os模块在系统管理中的应用
Jun 22 #Python
解决tensorflow读取本地MNITS_data失败的原因
Jun 22 #Python
python实现猜数游戏(保存游戏记录)
Jun 22 #Python
基于Tensorflow读取MNIST数据集时网络超时的解决方式
Jun 22 #Python
在Mac中配置Python虚拟环境过程解析
Jun 22 #Python
tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this T
Jun 22 #Python
You might like
小偷PHP+Html+缓存
2006/12/20 PHP
PHP 强制性文件下载功能的函数代码(任意文件格式)
2010/05/26 PHP
php array_pop()数组函数将数组最后一个单元弹出(出栈)
2011/07/12 PHP
php数据结构 算法(PHP描述) 简单选择排序 simple selection sort
2011/08/09 PHP
PHP中全局变量global和$GLOBALS[]的区别分析
2012/08/06 PHP
PHP中实现中文字串截取无乱码的解决方法
2018/05/29 PHP
JavaScript中null与undefined分析
2009/07/25 Javascript
JavaScript与DOM组合动态创建表格实例
2012/12/23 Javascript
JS验证IP,子网掩码,网关和MAC的方法
2015/07/02 Javascript
Jquery EasyUI实现treegrid上显示checkbox并取选定值的方法
2016/04/29 Javascript
JS动态生成年份和月份实例代码
2017/02/04 Javascript
jquery事件与绑定事件
2017/03/16 Javascript
用js将long型数据转换成date型或datetime型的实例
2017/07/03 Javascript
js求数组中全部数字可拼接出的最大整数示例代码
2017/08/25 Javascript
详解.vue文件中监听input输入事件(oninput)
2017/09/19 Javascript
BootStrap 标题设置跨行无效的解决方法
2017/10/25 Javascript
jquery ajaxfileupload异步上传插件
2017/11/21 jQuery
微信小程序如何调用新闻接口实现列表循环
2019/07/02 Javascript
vue-cli设置css不生效的解决方法
2020/02/07 Javascript
[55:56]NB vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.22
2019/09/05 DOTA
Python和Perl绘制中国北京跑步地图的方法
2016/03/03 Python
TensorFlow在MAC环境下的安装及环境搭建
2017/11/14 Python
把csv文件转化为数组及数组的切片方法
2018/07/04 Python
Python编程深度学习计算库之numpy
2018/12/28 Python
Python读取Excel数据并生成图表过程解析
2020/06/18 Python
微信html5页面调用第三方位置导航的示例
2018/03/14 HTML / CSS
Vans荷兰官方网站:美国南加州的原创极限运动潮牌
2018/01/23 全球购物
人事专员职责
2014/02/22 职场文书
《音乐之都维也纳》教学反思
2014/04/16 职场文书
园林专业毕业生自荐信
2014/07/04 职场文书
合伙经营协议书范本
2014/09/13 职场文书
被告代理词范文
2015/05/25 职场文书
详解Django中 render() 函数的使用方法
2021/04/22 Python
浅谈Python中的正则表达式
2021/06/28 Python
Spring Boot两种全局配置和两种注解的操作方法
2021/06/29 Java/Android
SQLServer 错误: 15404,无法获取有关 Windows NT 组/用户 WIN-8IVSNAQS8T7\Administrator 的信息
2021/06/30 SQL Server