keras分类模型中的输入数据与标签的维度实例


Posted in Python onJuly 03, 2020

在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
  
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
 # Create an all-zero matrix of shape (len(sequences), dimension)
 results = np.zeros((len(sequences), dimension))
 for i, sequence in enumerate(sequences):
  results[i, sequence] = 1. # set specific indices of results[i] to 1s
 return results


x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data() 
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
    verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
        width_shift_range=5./32,
        height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
     steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
     callbacks=callbacks,
     validation_data=(testX, Y_test),
     validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单窗口倒计时界面实例
May 05 Python
django接入新浪微博OAuth的方法
Jun 29 Python
python安装与使用redis的方法
Apr 19 Python
Python文件操作基本流程代码实例
Dec 11 Python
python爬取足球直播吧五大联赛积分榜
Jun 13 Python
PIL图像处理模块paste方法简单使用详解
Jul 17 Python
tensorflow 模型权重导出实例
Jan 24 Python
Python 开发工具PyCharm安装教程图文详解(新手必看)
Feb 28 Python
python实现爱奇艺登陆密码RSA加密的方法示例详解
May 27 Python
Python描述数据结构学习之哈夫曼树篇
Sep 07 Python
Pytorch 使用tensor特定条件判断索引
Apr 08 Python
Python实现Hash算法
Mar 18 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
You might like
PHP 遍历文件实现代码
2011/05/04 PHP
php学习笔记 数组遍历实现代码
2011/06/09 PHP
PHP MVC框架路由学习笔记
2016/03/02 PHP
Symfony查询方法实例小结
2017/06/28 PHP
由prototype_1.3.1进入javascript殿堂-类的初探
2006/11/06 Javascript
javascript之大字符串的连接的StringBuffer 类
2007/05/08 Javascript
手把手教你自己写一个js表单验证框架的方法
2010/09/14 Javascript
『jQuery』名称冲突使用noConflict方法解决
2013/04/22 Javascript
jQuery插件实现表格隔行换色且感应鼠标高亮行变色
2013/09/22 Javascript
jquery ajax jsonp跨域调用实例代码
2013/12/11 Javascript
js跨域访问示例(客户端/服务端)
2014/05/19 Javascript
Nodejs实现批量下载妹纸图
2015/05/28 NodeJs
React学习笔记之条件渲染(一)
2017/07/02 Javascript
jQuery实现table表格信息的展开和缩小功能示例
2018/07/21 jQuery
mpvue+vant app搭建微信小程序的方法步骤
2019/02/11 Javascript
详解element-ui 表单校验 Rules 配置 常用黑科技
2020/07/11 Javascript
Openlayers实现测量功能
2020/09/25 Javascript
js 图片懒加载的实现
2020/10/21 Javascript
[36:13]Mineski vs iG 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python中xrange和range的区别
2014/05/13 Python
python决策树之CART分类回归树详解
2017/12/20 Python
python多线程之事件Event的使用详解
2018/04/27 Python
python range()函数取反序遍历sequence的方法
2018/06/25 Python
Python3利用print输出带颜色的彩色字体示例代码
2019/04/08 Python
Python 获取命令行参数内容及参数个数的实例
2019/12/20 Python
CSS3常用的几种颜色渐变模式总结
2016/11/18 HTML / CSS
马克华菲官方商城:Mark Fairwhale
2016/09/04 全球购物
H&M旗下高端女装品牌:& Other Stories
2018/05/07 全球购物
空指针到底是什么
2012/08/07 面试题
建材业务员岗位职责
2013/12/08 职场文书
学生爱国演讲稿
2014/01/14 职场文书
商业项目策划方案
2014/06/05 职场文书
放飞梦想演讲稿600字
2014/08/26 职场文书
社区灵活就业证明
2014/11/03 职场文书
2016秋季田径运动会广播稿
2015/12/21 职场文书
java设计模式--原型模式详解
2021/07/21 Java/Android