浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python的Django框架实现事务交易管理的教程
Apr 20 Python
Python使用ftplib实现简易FTP客户端的方法
Jun 03 Python
基于python yield机制的异步操作同步化编程模型
Mar 18 Python
Python进行数据提取的方法总结
Aug 22 Python
Python中遇到的小问题及解决方法汇总
Jan 11 Python
Python输入二维数组方法
Apr 13 Python
解决Django migrate No changes detected 不能创建表的问题
May 27 Python
Python实现监控键盘鼠标操作示例【基于pyHook与pythoncom模块】
Sep 04 Python
Python 通过打码平台实现验证码的实现
May 13 Python
在SQLite-Python中实现返回、查询中文字段的方法
Jul 17 Python
Python程序慢的重要原因
Sep 04 Python
python Gabor滤波器讲解
Oct 26 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
php中取得URL的根域名的代码
2011/03/23 PHP
PHP与javascript实现变量交互的示例代码
2013/07/23 PHP
PHP批量检测并去除文件BOM头代码实例
2014/05/08 PHP
PHP和javascript常用正则表达式及用法实例
2014/07/01 PHP
Laravel中的Blade模板引擎示例详解
2017/10/10 PHP
在Ajax中使用Flash实现跨域数据读取的实现方法
2010/12/02 Javascript
jQuery学习笔记之控制页面实现代码
2012/02/27 Javascript
Javascript图像处理思路及实现代码
2012/12/25 Javascript
终于解决了IE8不支持数组的indexOf方法
2013/04/03 Javascript
使用jQuery UI的tooltip函数修饰title属性的气泡悬浮框
2013/06/24 Javascript
如何学习Javascript入门指导
2013/11/01 Javascript
jquery超简单实现手风琴效果的方法
2015/06/05 Javascript
Jquery 全选反选实例代码
2015/11/19 Javascript
Nodejs从有门道无门菜鸟起飞必看教程
2016/07/20 NodeJs
关于数据与后端进行交流匹配(点亮星星)
2016/08/03 Javascript
jQuery Ajax使用FormData对象上传文件的方法
2016/09/07 Javascript
xcode中获取js文件的路径方法(推荐)
2016/11/05 Javascript
js实现淡入淡出轮播切换功能
2017/01/13 Javascript
微信小程序实现换肤功能
2018/03/14 Javascript
vue对storejs获取的数据进行处理时遇到的几种问题小结
2018/03/20 Javascript
每天学点Vue源码之vm.$mount挂载函数
2019/03/11 Javascript
vue实现多条件和模糊搜索功能
2019/05/28 Javascript
JS回调函数原理与用法详解【附PHP回调函数】
2019/07/20 Javascript
详解Vue中CSS样式穿透问题
2019/09/12 Javascript
Python金融数据可视化汇总
2017/11/17 Python
使用Python自动生成HTML的方法示例
2019/08/06 Python
Python使用Tkinter实现转盘抽奖器的步骤详解
2020/01/06 Python
Python 实现向word(docx)中输出
2020/02/13 Python
利用PyQt5+Matplotlib 绘制静态/动态图的实现代码
2020/07/13 Python
python 爬取百度文库并下载(免费文章限定)
2020/12/04 Python
英国最大的独立玩具专卖店:The Entertainer
2019/09/06 全球购物
高中军训感想800字
2014/02/23 职场文书
应用外语系自荐信
2014/06/26 职场文书
2015年感恩节演讲稿(优选篇)
2015/03/20 职场文书
幼儿园保教工作总结2015
2015/10/15 职场文书
Nginx利用Logrotate实现日志分割
2022/05/20 Servers