keras分类模型中的输入数据与标签的维度实例


Posted in Python onJuly 03, 2020

在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
  
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
 # Create an all-zero matrix of shape (len(sequences), dimension)
 results = np.zeros((len(sequences), dimension))
 for i, sequence in enumerate(sequences):
  results[i, sequence] = 1. # set specific indices of results[i] to 1s
 return results


x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data() 
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
    verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
        width_shift_range=5./32,
        height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
     steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
     callbacks=callbacks,
     validation_data=(testX, Y_test),
     validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用reportlab将目录下所有的文本文件打印成pdf的方法
May 20 Python
用Python写飞机大战游戏之pygame入门(4):获取鼠标的位置及运动
Nov 05 Python
深入理解Django的中间件middleware
Mar 14 Python
Python matplotlib 画图窗口显示到gui或者控制台的实例
May 24 Python
解决python os.mkdir创建目录失败的问题
Oct 16 Python
Windows 8.1 64bit下搭建 Scrapy 0.22 环境
Nov 18 Python
Python Scapy随心所欲研究TCP协议栈
Nov 20 Python
python远程调用rpc模块xmlrpclib的方法
Jan 11 Python
python中的colorlog库使用详解
Jul 05 Python
python GUI库图形界面开发之PyQt5窗口背景与不规则窗口实例
Feb 25 Python
Django-Scrapy生成后端json接口的方法示例
Oct 06 Python
python空元组在all中返回结果详解
Dec 15 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
You might like
PHP array_multisort()函数的使用札记
2011/07/03 PHP
PHP安全防范技巧分享
2011/11/03 PHP
php中常量DIRECTORY_SEPARATOR用法深入分析
2014/11/14 PHP
网站防止被刷票的一些思路与方法
2015/01/08 PHP
PHP实现批量生成App各种尺寸Logo
2015/03/19 PHP
PHP策略模式定义与用法示例
2017/07/27 PHP
PHP实现文件上传操作和封装
2020/03/04 PHP
Laravel框架下的Contracts契约详解
2020/03/17 PHP
jquery下操作HTML控件的实现代码
2010/01/12 Javascript
基于jsTree的无限级树JSON数据的转换代码
2010/07/27 Javascript
javascript实现的一个随机点名功能
2014/08/26 Javascript
JavaScript列表框listbox全选和反选的实现方法
2015/03/18 Javascript
微信小程序 获取当前地理位置和经纬度实例代码
2016/12/05 Javascript
深入nodejs中流(stream)的理解
2017/03/27 NodeJs
深入理解vue中的$set
2017/06/01 Javascript
详解angular ui-grid之过滤器设置
2017/06/07 Javascript
Node.js 回调函数实例详解
2017/07/06 Javascript
在iframe中使bootstrap的模态框在父页面弹出问题
2017/08/07 Javascript
JS解决IOS中拍照图片预览旋转90度BUG的问题
2017/09/13 Javascript
Bootstrap 模态框多次显示后台提交多次BUG的解决方法
2017/12/26 Javascript
浅谈Node 调试工具入门教程
2018/03/20 Javascript
JS代码简洁方式之函数方法详解
2020/07/28 Javascript
查看django执行的sql语句及消耗时间的两种方法
2018/05/29 Python
python导入pandas具体步骤方法
2019/06/23 Python
python列表每个元素同增同减和列表元素去空格的实例
2019/07/20 Python
python实现批量修改文件名
2020/03/23 Python
大学生自我鉴定范文
2013/12/28 职场文书
商场消防管理制度
2014/01/12 职场文书
护士节活动总结
2014/08/29 职场文书
副乡长群众路线教育实践活动个人对照检查材料
2014/09/19 职场文书
2016年教师政治思想表现评语
2015/12/02 职场文书
《七月的天山》教学反思
2016/02/19 职场文书
2016年保险公众宣传日活动总结
2016/04/05 职场文书
JavaScript最完整的深浅拷贝实现方式详解
2022/02/28 Javascript
WCG2010 星际争霸决赛 Flash vs Goojila 1 星际经典比赛回顾
2022/04/01 星际争霸
JS轻量级函数式编程实现XDM二
2022/06/16 Javascript