浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python爬虫中urllib库的进阶学习
Jan 05 Python
Python模拟简单电梯调度算法示例
Aug 20 Python
Django 中自定义 Admin 样式与功能的实现方法
Jul 04 Python
python 利用浏览器 Cookie 模拟登录的用户访问知乎的方法
Jul 11 Python
用django设置session过期时间的方法解析
Aug 05 Python
PyQt5基本控件使用之消息弹出、用户输入、文件对话框的使用方法
Aug 06 Python
python web框架中实现原生分页
Sep 08 Python
python 利用已有Ner模型进行数据清洗合并代码
Dec 24 Python
python求前n个阶乘的和实例
Apr 02 Python
python 在sql语句中使用%s,%d,%f说明
Jun 06 Python
Python基于wordcloud及jieba实现中国地图词云图
Jun 09 Python
python基于pygame实现飞机大作战小游戏
Nov 19 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
PHP第一季视频教程(李炎恢+php100 不断更新)
2011/05/29 PHP
PHP YII框架开发小技巧之模型(models)中rules自定义验证规则
2015/11/16 PHP
YII分模块加载路由的实现方法
2018/10/01 PHP
学习YUI.Ext第七日-View&JSONView Part Two-一个画室网站的案例
2007/03/10 Javascript
jQuery学习总结之元素的相对定位和选择器(持续更新)
2011/04/26 Javascript
jquery实现ajax提交form表单的方法总结
2014/03/03 Javascript
jQuery响应enter键的实现思路
2014/04/18 Javascript
jquery实现类似EasyUI的页面布局可改变左右的宽度
2020/09/12 Javascript
JavaScript中的6种运算符总结
2014/10/16 Javascript
Javascript模块化编程详解
2014/12/01 Javascript
在JS中a标签加入单击事件屏蔽href跳转页面
2016/12/16 Javascript
基于js中style.width与offsetWidth的区别(详解)
2017/11/12 Javascript
JS中使用textPath实现线条上的文字
2017/12/25 Javascript
浅析JS中回调函数及用法
2018/07/25 Javascript
解决vuejs 使用value in list 循环遍历数组出现警告的问题
2018/09/26 Javascript
nodejs基础之多进程实例详解
2018/12/27 NodeJs
jquery传参及获取方式(两种方式)
2020/02/13 jQuery
Vue双向绑定实现原理与方法详解
2020/05/07 Javascript
elementUI同一页面展示多个Dialog的实现
2020/11/19 Javascript
Python使用函数默认值实现函数静态变量的方法
2014/08/18 Python
Python记录详细调用堆栈日志的方法
2015/05/05 Python
Python实现豆瓣图片下载的方法
2015/05/25 Python
Ruby使用eventmachine为HTTP服务器添加文件下载功能
2016/04/20 Python
Python简单检测文本类型的2种方法【基于文件头及cchardet库】
2016/09/18 Python
django中模板的html自动转意方法
2018/05/27 Python
python分割一个文本为多个文本的方法
2019/07/22 Python
Python3的高阶函数map,reduce,filter的示例详解
2019/07/23 Python
Python 仅获取响应头, 不获取实体的实例
2019/08/21 Python
Python基于Dlib的人脸识别系统的实现
2020/02/26 Python
医院实习接收函
2014/01/12 职场文书
营销总经理岗位职责
2014/02/02 职场文书
《秋姑娘的信》教学反思
2014/02/28 职场文书
户籍证明书标准模板
2014/09/10 职场文书
MySQL 逻辑备份与恢复测试的相关总结
2021/05/14 MySQL
Python Flask请求扩展与中间件相关知识总结
2021/06/11 Python
Python语言内置数据类型
2022/02/24 Python