浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中django框架通过正则搜索页面上email地址的方法
Mar 21 Python
python使用fileinput模块实现逐行读取文件的方法
Apr 29 Python
python基础练习之几个简单的游戏
Nov 10 Python
Python竟能画这么漂亮的花,帅呆了(代码分享)
Nov 15 Python
PyQt5每天必学之QSplitter实现窗口分隔
Apr 19 Python
Python简单计算给定某一年的某一天是星期几示例
Jun 27 Python
这可能是最好玩的python GUI入门实例(推荐)
Jul 19 Python
Python使用QQ邮箱发送邮件报错smtplib.SMTPAuthenticationError
Dec 20 Python
python获取依赖包和安装依赖包教程
Feb 13 Python
关于tf.TFRecordReader()函数的用法解析
Feb 17 Python
python编程的核心知识点总结
Feb 08 Python
在python中实现导入一个需要传参的模块
May 12 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
一个取得文件扩展名的函数
2006/10/09 PHP
PHP+AJAX实现无刷新注册(带用户名实时检测)
2006/12/02 PHP
Zend framework处理一个http请求的流程分析
2010/02/08 PHP
php中in_array函数用法分析
2014/11/15 PHP
Zend Framework入门之环境配置及第一个Hello World示例(附demo源码下载)
2016/03/21 PHP
JavaScipt基本教程之JavaScript语言的基础
2008/01/16 Javascript
javascript fullscreen全屏实现代码
2009/04/09 Javascript
JS 树形递归实例代码
2010/05/18 Javascript
jQuery中add实现同时选择两个id对象
2010/10/22 Javascript
innerHTML与jquery里的html()区别介绍
2012/10/12 Javascript
formvalidator验证插件中有关ajax验证问题
2013/01/04 Javascript
js实现div层缓慢收缩与展开的方法
2015/05/11 Javascript
js实现文字在按钮上滚动的方法
2015/08/20 Javascript
Bootstrap~多级导航(级联导航)的实现效果【附代码】
2016/03/08 Javascript
JS中使用textPath实现线条上的文字
2017/12/25 Javascript
JavaScript ES6常用基础知识总结
2019/02/09 Javascript
KnockoutJS数组比较算法实例详解
2019/11/25 Javascript
numpy.random.seed()的使用实例解析
2018/02/03 Python
python实现微信每日一句自动发送给喜欢的人
2019/04/29 Python
使用Python和Prometheus跟踪天气的使用方法
2019/05/06 Python
用Python从0开始实现一个中文拼音输入法的思路详解
2019/07/20 Python
python实现抠图给证件照换背景源码
2019/08/20 Python
浅谈tensorflow使用张量时的一些注意点tf.concat,tf.reshape,tf.stack
2020/06/23 Python
python报错: 'list' object has no attribute 'shape'的解决
2020/07/15 Python
CSS3动画之流彩文字效果+图片模糊效果+边框伸展效果实现代码合集
2017/08/18 HTML / CSS
TripAdvisor斯洛伐克:阅读评论、比较价格和酒店预订
2018/04/25 全球购物
皇家阿尔伯特瓷器美国官网:Royal Albert美国
2020/02/16 全球购物
伊莱克斯(Electrolux)俄罗斯网上商店:瑞典家用电器品牌
2021/01/23 全球购物
简述你对Statement,PreparedStatement,CallableStatement的理解
2013/03/25 面试题
什么是用户模式(User Mode)与内核模式(Kernel Mode) ?
2014/07/21 面试题
女大学生自我鉴定
2013/12/09 职场文书
车间副主任岗位职责
2013/12/24 职场文书
教师个人的自我评价分享
2014/01/02 职场文书
在校大学生的职业生涯规划书
2014/03/14 职场文书
法制教育演讲稿
2014/09/10 职场文书
2016领导干部廉洁自律心得体会
2016/01/13 职场文书