浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python的re模块应用实例
Sep 26 Python
python解决汉字编码问题:Unicode Decode Error
Jan 19 Python
VScode编写第一个Python程序HelloWorld步骤
Apr 06 Python
使用Python处理BAM的方法
Sep 28 Python
使用celery执行Django串行异步任务的方法步骤
Jun 06 Python
Python中的延迟绑定原理详解
Oct 11 Python
python中count函数简单的实例讲解
Feb 06 Python
python统计字符串中字母出现次数代码实例
Mar 02 Python
使用OpenCV获取图像某点的颜色值,并设置某点的颜色
Jun 02 Python
pytorch MSELoss计算平均的实现方法
May 12 Python
Pytorch数据读取之Dataset和DataLoader知识总结
May 23 Python
Flask使用SQLAlchemy实现持久化数据
Jul 16 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
php绘制一个扇形的方法
2015/01/24 PHP
PHP会话处理的10个函数
2015/08/11 PHP
PHP正则删除HTML代码中宽高样式的方法
2017/06/12 PHP
犀利的js 函数集合
2009/06/11 Javascript
js对象的比较
2011/02/26 Javascript
解析John Resig Simple JavaScript Inheritance代码
2012/12/03 Javascript
基于jquery实现百度新闻导航菜单滑动动画
2016/03/15 Javascript
AngularJS 中的事件详解
2016/07/28 Javascript
如何利用JSHint减少JavaScript的错误
2016/08/23 Javascript
针对后台列表table拖拽比较实用的jquery拖动排序
2016/10/10 Javascript
js HTML5上传示例代码完整版
2016/10/10 Javascript
js模拟微博发布消息
2017/02/23 Javascript
jQuery实现选项卡功能(两种方法)
2017/03/08 Javascript
Vuex之理解Store的用法
2017/04/19 Javascript
微信小程序本地缓存数据增删改查实例详解
2017/05/24 Javascript
Angular 4 指令快速入门教程
2017/06/07 Javascript
JavaScript中递归实现的方法及其区别
2017/09/12 Javascript
解析vue中的$mount
2017/12/21 Javascript
小程序开发踩坑:页面窗口定位(相对于浏览器定位)(推荐)
2019/04/25 Javascript
记录vue做微信自定义分享的一些问题
2019/09/12 Javascript
[34:10]Secret vs VG 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.24
2019/09/10 DOTA
Python translator使用实例
2008/09/06 Python
深入了解Python数据类型之列表
2016/06/24 Python
解决Python pandas df 写入excel 出现的问题
2018/07/04 Python
python super的使用方法及实例详解
2019/09/25 Python
python 上下文管理器及自定义原理解析
2019/11/19 Python
详解python百行有效代码实现汉诺塔小游戏(简约版)
2020/10/30 Python
Linux中如何设置Java环境变量(Ubuntu)
2016/07/24 面试题
什么是命名空间(NameSpace)
2015/11/24 面试题
办公室主任岗位职责
2015/01/31 职场文书
横店影视城导游词
2015/02/06 职场文书
中标通知书
2015/04/17 职场文书
家庭贫困证明
2015/06/16 职场文书
倡议书怎么写?
2019/04/11 职场文书
Linux安装Nginx步骤详解
2021/03/31 Servers
SQL语句中EXISTS的详细用法大全
2022/06/25 MySQL