keras 自定义loss model.add_loss的使用详解


Posted in Python onJune 22, 2020

一点见解,不断学习,欢迎指正

1、自定义loss层作为网络一层加进model,同时该loss的输出作为网络优化的目标函数

from keras.models import Model
import keras.layers as KL
import keras.backend as K
import numpy as np
from keras.utils.vis_utils import plot_model
 
x_train=np.random.normal(1,1,(100,784))
 
x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
def custom_loss1(y_true,y_pred):
 return K.mean(K.abs(y_true-y_pred))
loss1=KL.Lambda(lambda x:custom_loss1(*x),name='loss1')([x,x_in])
 
model = Model(x_in, [loss1])
model.get_layer('loss1').output#取出loss
model.add_loss(loss1)#作为网络优化的目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
#
model.fit(x_train, None, epochs=5)

2、自定义loss,作为网络优化的目标函数

x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
 
model = Model(x_in, x)
loss = K.mean((x - x_in)**2)
model.add_loss(loss)#只是作为loss优化目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
model.fit(x_train, None, epochs=5)

补充知识:keras load_weights fine-tune

分享一个小技巧,就是在构建网络模型的时候,不要怕麻烦,给每一层都定义一个名字,这样在复用之前的参数权重的时候,除了官网给的先加载权重,再冻结权重之外,你可以通过简单的修改层的名字来达到加载之前训练的权重的目的,假设权重文件保存为model_pretrain.h5 ,重新使用的时候,我把想要复用的层的名字设置成一样的,然后

model.load_weights('model_pretrain.h5', by_name=True)

以上这篇keras 自定义loss model.add_loss的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3基础之条件与循环控制实例解析
Aug 13 Python
跟老齐学Python之编写类之四再论继承
Oct 11 Python
Python中实现常量(Const)功能
Jan 28 Python
学习python中matplotlib绘图设置坐标轴刻度、文本
Feb 07 Python
pandas 数据实现行间计算的方法
Jun 08 Python
python 统计数组中元素出现次数并进行排序的实例
Jul 02 Python
详解python实现数据归一化处理的方式:(0,1)标准化
Jul 17 Python
django重新生成数据库中的某张表方法
Aug 28 Python
利用Python产生加密表和解密表的实现方法
Oct 15 Python
基于python计算滚动方差(标准差)talib和pd.rolling函数差异详解
Jun 08 Python
详解pytorch tensor和ndarray转换相关总结
Sep 03 Python
python中Matplotlib绘制直线的实例代码
Jul 04 Python
Python项目跨域问题解决方案
Jun 22 #Python
python os模块在系统管理中的应用
Jun 22 #Python
解决tensorflow读取本地MNITS_data失败的原因
Jun 22 #Python
python实现猜数游戏(保存游戏记录)
Jun 22 #Python
基于Tensorflow读取MNIST数据集时网络超时的解决方式
Jun 22 #Python
在Mac中配置Python虚拟环境过程解析
Jun 22 #Python
tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this T
Jun 22 #Python
You might like
php标签云的实现代码
2012/10/10 PHP
简单实用的.net DataTable导出Execl
2013/10/28 PHP
PHP封装的HttpClient类用法实例
2015/06/17 PHP
Laravel使用Queue队列的技巧汇总
2019/09/02 PHP
在chrome中window.onload事件的一些问题
2010/03/01 Javascript
event.X和event.clientX的区别分析
2011/10/06 Javascript
jquery事件机制扩展插件 jquery鼠标右键事件。
2011/12/26 Javascript
JS检测输入字符是否包含非法字符的示例代码
2014/02/11 Javascript
JavaScript基本数据类型及值类型和引用类型
2015/08/25 Javascript
如何防止JavaScript自动插入分号
2015/11/05 Javascript
详解js中Number()、parseInt()和parseFloat()的区别
2016/12/20 Javascript
layui导航栏实现代码
2017/05/19 Javascript
node.js 发布订阅模式的实例
2017/09/10 Javascript
微信小程序实现列表滚动头部吸顶的示例代码
2020/07/12 Javascript
解决Echarts2竖直datazoom滑动后显示数据不全的问题
2020/07/20 Javascript
我所理解的JavaScript中的this指向
2020/09/04 Javascript
[03:16]DOTA2完美大师赛主赛事首日集锦
2017/11/23 DOTA
[50:54]完美世界DOTA2联赛 GXR vs IO 第三场 11.07
2020/11/10 DOTA
python实现web方式logview的方法
2015/08/10 Python
简单谈谈python中的多进程
2016/11/06 Python
PyQt5每天必学之组合框
2018/04/20 Python
python如何发布自已pip项目的方法步骤
2018/10/09 Python
Pandas 按索引合并数据集的方法
2018/11/15 Python
Python 把序列转换为元组的函数tuple方法
2019/06/27 Python
50行Python代码获取高考志愿信息的实现方法
2019/07/23 Python
Python实现自动访问网页的例子
2020/02/21 Python
基于Python把网站域名解析成ip地址
2020/05/25 Python
python如何爬取网页中的文字
2020/07/28 Python
Python实现自动签到脚本功能
2020/08/20 Python
非常震撼的纯CSS3人物行走动画
2016/02/24 HTML / CSS
世界上最好的儿童品牌:AlexandAlexa
2018/01/27 全球购物
业务部门经理岗位职责
2014/02/23 职场文书
职员竞岗演讲稿
2014/05/14 职场文书
《比的意义》教学反思
2016/02/18 职场文书
Python中如何处理常见报错
2022/01/18 Python
详解MySQL的主键查询为什么这么快
2022/04/03 MySQL