keras 自定义loss model.add_loss的使用详解


Posted in Python onJune 22, 2020

一点见解,不断学习,欢迎指正

1、自定义loss层作为网络一层加进model,同时该loss的输出作为网络优化的目标函数

from keras.models import Model
import keras.layers as KL
import keras.backend as K
import numpy as np
from keras.utils.vis_utils import plot_model
 
x_train=np.random.normal(1,1,(100,784))
 
x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
def custom_loss1(y_true,y_pred):
 return K.mean(K.abs(y_true-y_pred))
loss1=KL.Lambda(lambda x:custom_loss1(*x),name='loss1')([x,x_in])
 
model = Model(x_in, [loss1])
model.get_layer('loss1').output#取出loss
model.add_loss(loss1)#作为网络优化的目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
#
model.fit(x_train, None, epochs=5)

2、自定义loss,作为网络优化的目标函数

x_in = KL.Input(shape=(784,))
x = x_in
x = KL.Dense(100, activation='relu')(x)
x = KL.Dense(784, activation='sigmoid')(x)
 
model = Model(x_in, x)
loss = K.mean((x - x_in)**2)
model.add_loss(loss)#只是作为loss优化目标函数
model.compile(optimizer='adam')
plot_model(model,to_file='model.png',show_shapes=True)
model.fit(x_train, None, epochs=5)

补充知识:keras load_weights fine-tune

分享一个小技巧,就是在构建网络模型的时候,不要怕麻烦,给每一层都定义一个名字,这样在复用之前的参数权重的时候,除了官网给的先加载权重,再冻结权重之外,你可以通过简单的修改层的名字来达到加载之前训练的权重的目的,假设权重文件保存为model_pretrain.h5 ,重新使用的时候,我把想要复用的层的名字设置成一样的,然后

model.load_weights('model_pretrain.h5', by_name=True)

以上这篇keras 自定义loss model.add_loss的使用详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现查找excel里某一列重复数据并且剔除后打印的方法
May 26 Python
实现python版本的按任意键继续/退出
Sep 26 Python
Python设计模式之MVC模式简单示例
Jan 10 Python
python实现教务管理系统
Mar 12 Python
python pandas 对series和dataframe的重置索引reindex方法
Jun 07 Python
pandas 数据归一化以及行删除例程的方法
Nov 10 Python
python修改txt文件中的某一项方法
Dec 29 Python
Python骚操作之动态定义函数
Mar 26 Python
pytorch获取模型某一层参数名及参数值方式
Dec 30 Python
python中xlrd模块的使用详解
Feb 01 Python
python用字节处理文件实例讲解
Apr 13 Python
Python机器学习之基于Pytorch实现猫狗分类
Jun 08 Python
Python项目跨域问题解决方案
Jun 22 #Python
python os模块在系统管理中的应用
Jun 22 #Python
解决tensorflow读取本地MNITS_data失败的原因
Jun 22 #Python
python实现猜数游戏(保存游戏记录)
Jun 22 #Python
基于Tensorflow读取MNIST数据集时网络超时的解决方式
Jun 22 #Python
在Mac中配置Python虚拟环境过程解析
Jun 22 #Python
tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this T
Jun 22 #Python
You might like
全国FM电台频率大全 - 30 宁夏回族自治区
2020/03/11 无线电
php+ajax实现无刷新的新闻留言系统
2020/12/21 PHP
PHP面向对象编程之深入理解方法重载与方法覆盖(多态)
2015/12/24 PHP
PHP微信公众号自动发送红包API
2016/06/01 PHP
jQuery控制图片的hover效果(smartRollover.js)
2012/03/18 Javascript
jquery实现兼容浏览器的图片上传本地预览功能
2013/10/14 Javascript
jquery及原生js获取select下拉框选中的值示例
2013/10/25 Javascript
浅析tr的隐藏和显示问题
2014/03/05 Javascript
控制文字内容的显示与隐藏示例
2014/06/11 Javascript
JavaScript获取当前url根目录(路径)
2016/06/17 Javascript
文件上传,iframe跨域数据提交的实现
2016/11/18 Javascript
js模糊查询实例分享
2016/12/26 Javascript
详解vue与后端数据交互(ajax):vue-resource
2017/03/16 Javascript
jQuery中过滤器的基本用法示例
2017/10/11 jQuery
Vue.js 实现微信公众号菜单编辑器功能(一)
2018/05/08 Javascript
使用 vue-i18n 切换中英文效果
2018/05/23 Javascript
vue项目实现github在线预览功能
2018/06/20 Javascript
javascript防抖函数debounce详解
2019/06/11 Javascript
layer设置maxWidth及maxHeight解决方案
2019/07/26 Javascript
python标准日志模块logging的使用方法
2013/11/01 Python
Python脚本实现集群检测和管理功能
2015/03/06 Python
老生常谈python之鸭子类和多态
2017/06/13 Python
20个常用Python运维库和模块
2018/02/12 Python
pandas中apply和transform方法的性能比较及区别介绍
2018/10/30 Python
前端面试必备之CSS3的新特性
2017/09/05 HTML / CSS
HTML5边玩边学(1)画布实现方法
2010/09/21 HTML / CSS
什么是Oracle的后台进程background processes?都有哪些后台进程?
2012/04/26 面试题
年度考核自我评价
2014/01/25 职场文书
会计专业应届生自荐信
2014/02/07 职场文书
《白鹅》教学反思
2014/04/13 职场文书
推荐信模板
2014/05/09 职场文书
高一新生军训方案
2014/05/12 职场文书
党的生日活动方案
2014/08/15 职场文书
工作证明范本(2篇)
2014/09/14 职场文书
单位接收函范文
2015/01/30 职场文书
详解MongoDB的条件查询和排序
2021/06/23 MongoDB