浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
零基础写python爬虫之抓取糗事百科代码分享
Nov 06 Python
在主机商的共享服务器上部署Django站点的方法
Jul 22 Python
python实现斐波那契数列的方法示例
Jan 12 Python
python类的方法属性与方法属性的动态绑定代码详解
Dec 27 Python
Python基于Tkinter模块实现的弹球小游戏
Dec 27 Python
python 读取鼠标点击坐标的实例
Dec 29 Python
django mysql数据库及图片上传接口详解
Jul 18 Python
Python趣味入门教程之循环语句while
Aug 26 Python
python中用Scrapy实现定时爬虫的实例讲解
Jan 18 Python
Python使用UDP实现720p视频传输的操作
Apr 24 Python
python3读取文件指定行的三种方法
May 24 Python
Python Matplotlib绘制条形图的全过程
Oct 24 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
VOLVO车载收音机
2021/03/02 无线电
php实现对两个数组进行减法操作的方法
2015/04/17 PHP
PHP 访问数据库配置通用方法(json)
2018/05/20 PHP
laravel获取不到session的三种解决办法【推荐】
2018/09/16 PHP
从ThinkPHP3.2.3过渡到ThinkPHP5.0学习笔记图文详解
2019/04/03 PHP
javascript 有趣而诡异的数组
2009/04/06 Javascript
Jquery CheckBox全选方法代码附js checkbox全选反选代码
2010/06/09 Javascript
文本框获得焦点和失去焦点的判断代码
2012/03/18 Javascript
JavaScript实现班级随机点名小应用需求的具体分析
2014/05/12 Javascript
JS公共小方法之判断对象是否为domElement的实例
2016/11/25 Javascript
canvas绘制一个常用的emoji表情
2017/03/30 Javascript
node vue项目开发之前后端分离实战记录
2017/12/13 Javascript
vue-cli打包后本地运行dist文件中的index.html操作
2020/08/12 Javascript
uin-app+mockjs实现本地数据模拟
2020/08/26 Javascript
Windows中安装使用Virtualenv来创建独立Python环境
2016/05/31 Python
浅谈python 四种数值类型(int,long,float,complex)
2016/06/08 Python
Python实现破解12306图片验证码的方法分析
2017/12/29 Python
python安装模块如何通过setup.py安装(超简单)
2018/05/05 Python
Python利用requests模块下载图片实例代码
2019/08/12 Python
Python AutoCAD 系统设置的实现方法
2020/04/01 Python
appium+python自动化配置(adk、jdk、node.js)
2020/11/17 Python
html5简单示例_动力节点Java学院整理
2017/07/07 HTML / CSS
泰国时尚电商:POMELO Fashion
2020/03/11 全球购物
人事助理自荐信
2014/02/02 职场文书
五星级酒店餐饮部总监的标准岗位职责
2014/02/17 职场文书
新品发布会主持词
2014/04/02 职场文书
社会公德演讲稿
2014/05/20 职场文书
毕业生党员个人总结
2015/02/14 职场文书
护士求职自荐信范文
2015/03/04 职场文书
保险公司2016开门红口号集锦
2015/12/24 职场文书
2016年度优秀辅导员事迹材料
2016/02/26 职场文书
2016年社区党支部公开承诺书
2016/03/25 职场文书
pytorch实现ResNet结构的实例代码
2021/05/17 Python
解析高可用Redis服务架构分析与搭建方案
2021/06/20 Redis
python基础之模块的导入
2021/10/24 Python
win10音频服务未响应怎么解决?win10音频服务未响应未修复的解决方法
2022/08/14 数码科技