浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python编写一个每天都在系统下新建一个文件夹的脚本
May 04 Python
python修改操作系统时间的方法
May 18 Python
Python冒泡排序注意要点实例详解
Sep 09 Python
Apache如何部署django项目
May 21 Python
Python利用QQ邮箱发送邮件的实现方法(分享)
Jun 09 Python
Python使用flask框架操作sqlite3的两种方式
Jan 31 Python
Python generator生成器和yield表达式详解
Aug 08 Python
python实现超市商品销售管理系统
Oct 25 Python
Python 3 使用Pillow生成漂亮的分形树图片
Dec 24 Python
计算Python Numpy向量之间的欧氏距离实例
May 22 Python
tensorflow 动态获取 BatchSzie 的大小实例
Jun 30 Python
Python排序算法之插入排序及其优化方案详解
Jun 11 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
php foreach循环中使用引用的问题
2013/11/06 PHP
php metaphone()函数的定义和用法
2016/05/15 PHP
图片完美缩放
2006/09/07 Javascript
JQuery自定义事件的应用 JQuery最佳实践
2010/08/01 Javascript
JavaScript 学习笔记之基础中的基础
2015/01/13 Javascript
JS中页面与页面之间超链接跳转中文乱码问题的解决办法
2016/12/15 Javascript
快速实现JS图片懒加载(可视区域加载)示例代码
2017/01/04 Javascript
node.js中grunt和gulp的区别详解
2017/07/17 Javascript
微信小程序 配置顶部导航条标题颜色的实现方法
2017/09/20 Javascript
node通过npm写一个cli命令行工具
2017/10/12 Javascript
Vue一个案例引发的递归组件的使用详解
2018/11/15 Javascript
微信小程序实现禁止分享代码实例
2019/10/19 Javascript
Vue数字输入框组件的使用方法
2019/10/19 Javascript
vue穿梭框实现上下移动
2021/01/29 Vue.js
[01:01:22]VGJ.S vs OG 2018国际邀请赛淘汰赛BO3 第一场 8.22
2018/08/23 DOTA
Python中处理字符串之isalpha()方法的使用
2015/05/18 Python
Python中的一些陷阱与技巧小结
2015/07/10 Python
详解Python中的Descriptor描述符类
2016/06/14 Python
python递归查询菜单并转换成json实例
2017/03/27 Python
浅谈python import引入不同路径下的模块
2017/07/11 Python
python计算auc指标实例
2017/07/13 Python
python生成1行四列全2矩阵的方法
2018/08/04 Python
详解python中递归函数
2019/04/16 Python
OpenCV3.0+Python3.6实现特定颜色的物体追踪
2019/07/23 Python
python实现梯度下降法
2020/03/24 Python
Python实现Appium端口检测与释放的实现
2020/12/31 Python
CSS3实现淘宝留白的方法
2020/06/05 HTML / CSS
世界领先的在线地板和建筑材料批发商:BuildDirect
2017/02/26 全球购物
彼得罗夫美国官网:Peter Thomas Roth美国(青瓜面膜)
2017/11/05 全球购物
HelloFresh奥地利:立即订购烹饪盒
2019/02/22 全球购物
化学相关工作求职信
2013/10/02 职场文书
门诊手术室工作制度
2014/01/30 职场文书
电脑售后服务承诺书
2014/03/27 职场文书
欢迎词怎么写
2015/01/23 职场文书
2016中学教师读书心得体会
2016/01/13 职场文书
nginx反向代理配置去除前缀案例教程
2021/07/26 Servers