keras分类模型中的输入数据与标签的维度实例


Posted in Python onJuly 03, 2020

在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
  
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
 # Create an all-zero matrix of shape (len(sequences), dimension)
 results = np.zeros((len(sequences), dimension))
 for i, sequence in enumerate(sequences):
  results[i, sequence] = 1. # set specific indices of results[i] to 1s
 return results


x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data() 
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
    verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
        width_shift_range=5./32,
        height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
     steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
     callbacks=callbacks,
     validation_data=(testX, Y_test),
     validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pycharm 使用心得(三)Hello world!
Jun 05 Python
用Python输出一个杨辉三角的例子
Jun 13 Python
Python中绑定与未绑定的类方法用法分析
Apr 29 Python
Python获取文件所在目录和文件名的方法
Jan 12 Python
Python实现识别手写数字 Python图片读入与处理
Mar 23 Python
windows下cx_Freeze生成Python可执行程序的详细步骤
Oct 09 Python
对python多线程与global变量详解
Nov 09 Python
Python设计模式之命令模式原理与用法实例分析
Jan 11 Python
Python之pymysql的使用小结
Jul 01 Python
Python实现隐马尔可夫模型的前向后向算法的示例代码
Dec 31 Python
Python通过socketserver处理多个链接
Mar 18 Python
Python Opencv中用compareHist函数进行直方图比较对比图片
Apr 07 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
You might like
弄了个检测传输的参数是否为数字的Function
2006/12/06 PHP
真正的ZIP文件操作类(php)
2007/07/21 PHP
zen cart新进商品的随机排序修改方法
2010/09/10 PHP
ThinkPHP控制器详解
2015/07/27 PHP
php socket通信简单实现
2016/11/18 PHP
thinkPHP中验证码的简单实现方法
2016/12/05 PHP
workerman结合laravel开发在线聊天应用的示例代码
2018/10/30 PHP
论坛特效代码收集(落伍转发-不错)
2006/12/02 Javascript
js实现权限树的更新权限时的全选全消功能
2009/02/17 Javascript
javascript 获取页面的高度及滚动条的位置的代码
2010/05/06 Javascript
通过Javascript将数据导出到外部Excel文档的函数代码
2012/06/15 Javascript
Javascript开发之三数组对象实例介绍
2012/11/12 Javascript
jquery中交替点击事件toggle方法的使用示例
2013/12/08 Javascript
Javascript全局变量var与不var的区别深入解析
2013/12/09 Javascript
Javascript实现颜色rgb与16进制转换的方法
2015/04/18 Javascript
Extjs让combobox写起来简洁又漂亮
2017/01/05 Javascript
Bootstrap实现提示框和弹出框效果
2017/01/11 Javascript
Mongoose经常返回e11000 error的原因分析
2017/03/29 Javascript
详解webpack+vue-cli项目打包技巧
2017/06/17 Javascript
ng-events类似ionic中Events的angular全局事件
2018/09/05 Javascript
编写v-for循环的技巧汇总
2020/12/01 Javascript
[02:39]DOTA2英雄基础教程 天怒法师
2013/11/29 DOTA
Python 连连看连接算法
2008/11/22 Python
Python获取Linux系统下的本机IP地址代码分享
2014/11/07 Python
Python使用metaclass实现Singleton模式的方法
2015/05/05 Python
Python request设置HTTPS代理代码解析
2018/02/12 Python
python基础教程项目二之画幅好画
2018/04/02 Python
pandas基于时间序列的固定时间间隔求均值的方法
2019/07/04 Python
PyTorch中的C++扩展实现
2020/04/02 Python
英国运动风奢侈品购物网站:Maison De Fashion
2020/08/28 全球购物
生物医学工程专业学生求职信范文分享
2013/12/14 职场文书
关于读书的演讲稿800字
2014/08/27 职场文书
自荐信格式模板
2015/03/27 职场文书
围城读书笔记
2015/06/26 职场文书
助学金申请书该怎么写?
2019/07/16 职场文书
NGINX 权限控制文件预览和下载的实现原理
2022/01/18 Servers