浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 中文字符串的处理实现代码
Oct 25 Python
Python实现读取目录所有文件的文件名并保存到txt文件代码
Nov 22 Python
Python二分查找详解
Sep 13 Python
基于python爬虫数据处理(详解)
Jun 10 Python
python实现学生信息管理系统
Apr 05 Python
Python Pexpect库的简单使用方法
Jan 29 Python
Pyqt5如何让QMessageBox按钮显示中文示例代码
Apr 11 Python
微信公众号token验证失败解决方案
Jul 22 Python
解决python3 安装不了PIL的问题
Aug 16 Python
解决tensorflow由于未初始化变量而导致的错误问题
Jan 06 Python
python字典和json.dumps()的遇到的坑分析
Mar 11 Python
用Python制作mini翻译器的实现示例
Aug 17 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
Php部分常见问题总结
2006/10/09 PHP
PHP 数组教程 定义数组
2009/10/23 PHP
PHP中的错误处理、异常处理机制分析
2012/05/07 PHP
PHP安装threads多线程扩展基础教程
2015/11/17 PHP
php实现当前页面点击下载文件的实例代码
2016/11/16 PHP
SlideView 图片滑动(扩展/收缩)展示效果
2010/08/01 Javascript
javascript跑马灯悬停放大效果实现代码
2012/12/12 Javascript
javaScript实现浮点数转十六进制字符
2013/10/29 Javascript
JQuery弹出炫丽对话框的同时让背景变灰色
2014/05/22 Javascript
JS实现的通用表单验证插件完整实例
2015/08/20 Javascript
基于jQuery的网页影音播放器jPlayer的基本使用教程
2016/03/08 Javascript
jQuery实现自动输入email、时间和域名的方法
2016/08/24 Javascript
js Canvas绘制圆形时钟教程
2017/02/06 Javascript
详解Bootstrap 学习(一)入门
2019/04/12 Javascript
jQuery内容选择器与表单选择器实例分析
2019/06/28 jQuery
重学JS之显示强制类型转换详解
2019/06/30 Javascript
BootstrapValidator验证用户名已存在(ajax)
2019/11/08 Javascript
原生JavaScript实现留言板
2021/01/10 Javascript
python连接mysql调用存储过程示例
2014/03/05 Python
Python实现子类调用父类的方法
2014/11/10 Python
在Python的Django框架的视图中使用Session的方法
2015/07/23 Python
基于python list对象中嵌套元组使用sort时的排序方法
2018/04/18 Python
python Pandas库基础分析之时间序列的处理详解
2019/07/13 Python
django 配置阿里云OSS存储media文件的例子
2019/08/20 Python
在OpenCV里实现条码区域识别的方法示例
2019/12/04 Python
python爬虫开发之urllib模块详细使用方法与实例全解
2020/03/09 Python
Yves Rocher伊夫·黎雪美国官网:法国始创植物美肌1959
2019/01/09 全球购物
军训自我鉴定范文
2014/02/13 职场文书
2014年党员公开承诺书范文
2014/03/28 职场文书
企业仓管员岗位职责
2014/06/15 职场文书
2014教育局对照检查材料思想汇报
2014/09/23 职场文书
车辆年检委托书范本
2014/10/14 职场文书
2015年端午节活动策划书
2015/05/05 职场文书
少先队入队仪式主持词
2015/07/04 职场文书
python中的被动信息搜集
2021/04/29 Python
浅谈MySQL表空间回收的正确姿势
2021/10/05 MySQL