keras分类模型中的输入数据与标签的维度实例


Posted in Python onJuly 03, 2020

在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
  
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
 # Create an all-zero matrix of shape (len(sequences), dimension)
 results = np.zeros((len(sequences), dimension))
 for i, sequence in enumerate(sequences):
  results[i, sequence] = 1. # set specific indices of results[i] to 1s
 return results


x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data() 
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
    verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
        width_shift_range=5./32,
        height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
     steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
     callbacks=callbacks,
     validation_data=(testX, Y_test),
     validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python连接sql server乱码的解决方法
Jan 28 Python
下载给定网页上图片的方法
Feb 18 Python
Python制作数据导入导出工具
Jul 31 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 Python
Python3中的列表生成式、生成器与迭代器实例详解
Jun 11 Python
对python生成业务报表的实例详解
Feb 03 Python
python 利用pywifi模块实现连接网络破解wifi密码实时监控网络
Sep 16 Python
python代码实现将列表中重复元素之间的内容全部滤除
May 22 Python
Django之腾讯云短信的实现
Jun 12 Python
python调用私有属性的方法总结
Jul 24 Python
Matplotlib中%matplotlib inline如何使用
Jul 28 Python
Python中Schedule模块使用详解 周期任务神器
Apr 19 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
You might like
星际原理概述
2020/03/04 星际争霸
CI框架学习笔记(一) - 环境安装、基本术语和框架流程
2014/10/26 PHP
PHPExcel在linux环境下导出报500错误的解决方法
2017/01/26 PHP
使javascript也能包含文件
2006/10/26 Javascript
[转]JS宝典学习笔记
2007/02/07 Javascript
JavaScript Object的extend是一个常用的功能
2009/12/02 Javascript
基于prototype扩展的JavaScript常用函数库
2010/11/30 Javascript
web的各种前端打印方法之jquery打印插件PrintArea实现网页打印
2013/01/09 Javascript
jQuery满屏焦点图左右滚动特效代码分享
2015/09/07 Javascript
jQuery使用Layer弹出层插件闪退问题
2016/12/22 Javascript
原生JS实现圣旨卷轴展开效果
2017/03/06 Javascript
vue-cli+webpack记事本项目创建
2017/04/01 Javascript
详解nodejs模板引擎制作
2017/06/14 NodeJs
Vue2.0中集成UEditor富文本编辑器的方法
2018/03/03 Javascript
nodejs实现套接字服务功能详解
2018/06/21 NodeJs
vue实现的网易云音乐在线播放和下载功能案例
2019/02/18 Javascript
javascript function(函数类型)使用与注意事项小结
2019/06/10 Javascript
[01:31:02]TNC vs VG 2019国际邀请赛淘汰赛 胜者组赛BO3 第一场
2019/08/22 DOTA
tensorflow: 查看 tensor详细数值方法
2018/06/13 Python
Python基础之条件控制操作示例【if语句】
2019/03/23 Python
解决Python内层for循环如何break出外层的循环的问题
2019/06/24 Python
pandas实现to_sql将DataFrame保存到数据库中
2019/07/03 Python
Python随机函数库random的使用方法详解
2019/08/21 Python
pycharm Tab键设置成4个空格的操作
2021/02/26 Python
JD Sports意大利:英国篮球和运动时尚的领导者
2017/10/29 全球购物
台湾家适得:Homeget
2019/02/11 全球购物
zooplus波兰:在线宠物店
2019/07/21 全球购物
如何定义一个可复用的服务
2014/09/30 面试题
计算机系毕业生推荐信
2013/11/06 职场文书
法人授权委托书范本
2014/04/04 职场文书
品牌服务方案
2014/06/03 职场文书
2014感恩节演讲稿大全
2014/10/11 职场文书
入党申请书格式
2019/06/20 职场文书
python实现简单的聊天小程序
2021/07/07 Python
选购到合适的激光打印机
2022/04/21 数码科技
pytest实现多进程与多线程运行超好用的插件
2022/07/15 Python