浅谈keras中loss与val_loss的关系


Posted in Python onJune 22, 2020

loss函数如何接受输入值

keras封装的比较厉害,官网给的例子写的云里雾里,

在stackoverflow找到了答案

You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).

def custom_loss_wrapper(input_tensor):
 def custom_loss(y_true, y_pred):
  return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
 return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')

You can verify that input_tensor and the loss value will change as different X is passed to the model.

X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642

X *= 1000
model.test_on_batch(X, y) # => 511.15466

fit_generator

fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.

Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)

### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)

补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签

我就废话不多说了,大家还是直接看代码吧~

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()

train_images = train_images / 255.0
test_images = test_images / 255.0

# plt.figure(figsize=(10,10))
# for i in range(25):
#  plt.subplot(5,5,i+1)
#  plt.xticks([])
#  plt.yticks([])
#  plt.grid(False)
#  plt.imshow(train_images[i], cmap=plt.cm.binary)
#  plt.xlabel(class_names[train_labels[i]])
# plt.show()

model = keras.Sequential([
 keras.layers.Flatten(input_shape=(28, 28)),
 keras.layers.Dense(128, activation='relu'),
 keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
    loss='categorical_crossentropy', 
    #loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
    metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)

model.fit(train_images, one_hot_train_labels, epochs=10)

one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)

print('\nTest accuracy:', test_acc)

# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]

loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,

而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可

以上这篇浅谈keras中loss与val_loss的关系就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用xauth方式登录饭否网然后发消息
Apr 11 Python
浅谈django rest jwt vue 跨域问题
Oct 26 Python
python多线程并发让两个LED同时亮的方法
Feb 18 Python
在Python中如何传递任意数量的实参的示例代码
Mar 21 Python
Tensorflow 定义变量,函数,数值计算等名字的更新方式
Feb 10 Python
tensorboard显示空白的解决
Feb 15 Python
浅谈Python程序的错误:变量未定义
Jun 02 Python
opencv 实现特定颜色线条提取与定位操作
Jun 02 Python
python如何编写win程序
Jun 08 Python
matplotlib实现数据实时刷新的示例代码
Jan 05 Python
pytorch 实现在测试的时候启用dropout
May 27 Python
Python find()、rfind()方法及作用
Dec 24 Python
python实现简易版学生成绩管理系统
Jun 22 #Python
python能否java成为主流语言吗
Jun 22 #Python
python让函数不返回结果的方法
Jun 22 #Python
python实现学生成绩测评系统
Jun 22 #Python
python算的上脚本语言吗
Jun 22 #Python
Python读取二进制文件代码方法解析
Jun 22 #Python
怎么快速自学python
Jun 22 #Python
You might like
利用 window_onload 实现select默认选择
2006/10/09 PHP
php2html php生成静态页函数
2008/12/08 PHP
PHP 输出简单动态WAP页面
2009/06/09 PHP
10条PHP高级技巧[修正版]
2011/08/02 PHP
thinkPHP实现的联动菜单功能详解
2017/05/05 PHP
PHP实现二叉树深度优先遍历(前序、中序、后序)和广度优先遍历(层次)实例详解
2018/04/20 PHP
日期 时间js控件
2009/05/07 Javascript
jQuery ajax在GBK编码下表单提交终极解决方案(非二次编码方法)
2010/10/20 Javascript
JS获取图片实际宽高及根据图片大小进行自适应
2013/08/11 Javascript
基于jquery实现等比缩放图片
2014/12/03 Javascript
node.js中的http.response.addTrailers方法使用说明
2014/12/14 Javascript
jquery插件unobtrusive实现片段式加载
2015/06/15 Javascript
基于javascript代码实现通过点击图片显示原图片
2015/11/29 Javascript
NodeJS远程代码执行
2016/08/28 NodeJs
基于vue-router 多级路由redirect 重定向的问题
2018/09/03 Javascript
Vue引用Swiper4插件无法重写分页器样式的解决方法
2018/09/27 Javascript
详解webpack之图片引入-增强的file-loader:url-loader
2018/10/08 Javascript
Vue CLI项目 axios模块前后端交互的使用(类似ajax提交)
2019/09/01 Javascript
vue element-ui实现动态面包屑导航
2019/12/23 Javascript
Python中使用dom模块生成XML文件示例
2015/04/05 Python
Python之Web框架Django项目搭建全过程
2017/05/02 Python
python使用pyqt写带界面工具的示例代码
2017/10/23 Python
python中pip的安装与使用教程
2018/08/10 Python
dataframe 按条件替换某一列中的值方法
2019/01/29 Python
Python生成指定数量的优惠码实操内容
2019/06/18 Python
pytorch打印网络结构的实例
2019/08/19 Python
详解Python3定时器任务代码
2019/09/23 Python
使用tensorflow实现VGG网络,训练mnist数据集方式
2020/05/26 Python
使用Python Tkinter实现剪刀石头布小游戏功能
2020/10/23 Python
.net软件工程师应聘上机试题
2015/03/10 面试题
简短的公司员工自我评价分享
2013/11/13 职场文书
暑期社会实践学生的自我评价
2014/01/09 职场文书
《金钱的魔力》教学反思
2014/02/24 职场文书
房务中心文员岗位职责
2014/04/16 职场文书
教务处干事工作总结
2015/08/14 职场文书
党员干部学习十八届五中全会精神心得体会
2016/01/05 职场文书