浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用paramiko模块实现ssh远程登陆上传文件并执行
Jan 27 Python
python实现自动重启本程序的方法
Jul 09 Python
使用python调用zxing库生成二维码图片详解
Jan 10 Python
Linux下python与C++使用dlib实现人脸检测
Jun 29 Python
Python中super函数用法实例分析
Mar 18 Python
python3使用matplotlib绘制散点图
Mar 19 Python
django框架用户权限中的session缓存到redis中的方法
Aug 06 Python
对python中的装包与解包实例详解
Aug 24 Python
TensorFlow实现checkpoint文件转换为pb文件
Feb 10 Python
python 实现朴素贝叶斯算法的示例
Sep 30 Python
python实现简单猜单词游戏
Dec 24 Python
Python如何使用logging为Flask增加logid
Mar 30 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
PHP 生成N个不重复的随机数
2015/01/21 PHP
PHP 5.3和PHP 5.4出现FastCGI Error解决方法
2015/02/12 PHP
Yii2 中实现单点登录的方法
2018/03/09 PHP
动态修改DOM 里面的 id 属性的弊端分析
2008/09/03 Javascript
jquery 锁定弹出层实现代码
2010/02/23 Javascript
jquery $.ajax()取xml数据的小问题解决方法
2010/11/20 Javascript
js实现无需数据库的县级以上联动行政区域下拉控件
2013/08/14 Javascript
jquery获取一个元素下面相同子元素的个数代码
2014/07/31 Javascript
js仿苹果iwatch外观的计时器代码分享
2015/08/26 Javascript
使用jQuery监听DOM元素大小变化
2016/02/24 Javascript
Node.js实现文件上传
2016/07/05 Javascript
JCrop+ajaxUpload 图像切割上传的实例代码
2016/07/20 Javascript
jQuery的ajax和遍历数组json实例代码
2016/08/01 Javascript
javascript的几种写法总结
2016/09/30 Javascript
vue.js初学入门教程(1)
2016/11/03 Javascript
nodejs操作mysql实现增删改查的实例
2017/05/28 NodeJs
浅谈原型对象的常用开发模式
2017/07/22 Javascript
fastadmin中调用js的方法
2019/05/14 Javascript
梯度下降法介绍及利用Python实现的方法示例
2017/07/12 Python
网红编程语言Python将纳入高考你怎么看?
2018/06/07 Python
python集合是否可变总结
2019/06/20 Python
在Python中利用pickle保存变量的实例
2019/12/30 Python
将自己的数据集制作成TFRecord格式教程
2020/02/17 Python
幼儿教师工作感言
2014/02/14 职场文书
课前三分钟演讲稿
2014/04/24 职场文书
社会实践先进工作者事迹材料
2014/05/06 职场文书
2014年庆祝国庆65周年演讲稿
2014/09/21 职场文书
销售员工作检讨书(推荐篇)
2014/10/18 职场文书
综治维稳工作汇报
2014/10/27 职场文书
班主任自我评价范文
2015/03/11 职场文书
我在伊朗长大观后感
2015/06/16 职场文书
2016年三八节红领巾广播稿
2015/12/17 职场文书
python实现求纯色彩图像的边框
2021/04/08 Python
python基于tkinter制作m3u8视频下载工具
2021/04/24 Python
MySQL删除和插入数据很慢的问题解决
2021/06/03 MySQL
Win11如何启用启动修复 ? Win11执行启动修复的三种方法
2022/04/08 数码科技