浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的多线程http压力测试代码
Feb 08 Python
Python 爬虫之超链接 url中含有中文出错及解决办法
Aug 03 Python
对python GUI实现完美进度条的示例详解
Dec 13 Python
python对象与json相互转换的方法
May 07 Python
python excel转换csv代码实例
Aug 26 Python
Python列表切片常用操作实例解析
Dec 16 Python
python用WxPython库实现无边框窗体和透明窗体实现方法详解
Feb 21 Python
Tensorflow中的图(tf.Graph)和会话(tf.Session)的实现
Apr 22 Python
django queryset相加和筛选教程
May 18 Python
python 引用传递和值传递详解(实参,形参)
Jun 05 Python
python cv2.resize函数high和width注意事项说明
Jul 05 Python
Pytorch实现WGAN用于动漫头像生成
Mar 04 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
PHP Global定义全局变量使用说明
2013/08/15 PHP
一个PHP二维数组排序的函数分享
2014/01/17 PHP
用PHP解决的一个栈的面试题
2014/07/02 PHP
CodeIgniter扩展核心类实例详解
2016/01/20 PHP
PHP filesize函数用法浅析
2019/02/15 PHP
用js得到网页中所有的div的id
2020/10/19 Javascript
javascript instanceof,typeof的区别
2010/03/24 Javascript
Jquery ajaxsubmit上传图片实现代码
2010/11/04 Javascript
jQuery 瀑布流 绝对定位布局(二)(延迟AJAX加载图片)
2012/05/23 Javascript
Javascript MVC框架Backbone.js详解
2014/09/18 Javascript
JS加载器如何动态加载外部js文件
2016/05/26 Javascript
AngularJS监听路由的变化示例代码
2016/09/23 Javascript
jQuery实现的购物车物品数量加减功能代码
2016/11/16 Javascript
jQuery自定义元素右键点击事件(实现案例)
2017/04/28 jQuery
BootStrap表单时间选择器详解
2017/05/09 Javascript
Javascript实现购物车功能的详细代码
2018/05/08 Javascript
vue cli2.0单页面title修改方法
2018/06/07 Javascript
JavaScript字符串处理常见操作方法小结
2019/11/15 Javascript
JS操作Fckeditor的一些常用方法(获取、插入等)
2020/02/19 Javascript
[03:02]生活中的Dendi之野外度假篇
2016/08/09 DOTA
[38:21]2018DOTA2亚洲邀请赛3月30日 小组赛A组 LGD VS Newbee
2018/03/31 DOTA
详解Python3 中hasattr()、getattr()、setattr()、delattr()函数及示例代码数
2018/04/18 Python
浅谈python的dataframe与series的创建方法
2018/11/12 Python
python多任务及返回值的处理方法
2019/01/22 Python
pycharm创建scrapy项目教程及遇到的坑解析
2019/08/15 Python
python list转置和前后反转的例子
2019/08/26 Python
Python多线程多进程实例对比解析
2020/03/12 Python
python如何从键盘获取输入实例
2020/06/18 Python
浅谈tensorflow使用张量时的一些注意点tf.concat,tf.reshape,tf.stack
2020/06/23 Python
html5 Canvas画图教程(2)—画直线与设置线条的样式如颜色/端点/交汇点
2013/01/09 HTML / CSS
施华洛世奇意大利官网:SWAROVSKI意大利
2018/07/23 全球购物
Unix如何添加新的用户
2014/08/20 面试题
小学生环保倡议书
2014/05/15 职场文书
护理专业自荐书
2014/06/04 职场文书
注册资产评估专业求职信
2014/07/16 职场文书
【2·13】一图读懂中国无线电发展
2022/02/18 无线电