keras分类模型中的输入数据与标签的维度实例


Posted in Python onJuly 03, 2020

在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
  
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
 # Create an all-zero matrix of shape (len(sequences), dimension)
 results = np.zeros((len(sequences), dimension))
 for i, sequence in enumerate(sequences):
  results[i, sequence] = 1. # set specific indices of results[i] to 1s
 return results


x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data() 
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
    verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
        width_shift_range=5./32,
        height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
     steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
     callbacks=callbacks,
     validation_data=(testX, Y_test),
     validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Heroku云平台上部署Python的Django框架的教程
Apr 20 Python
Python实现多条件筛选目标数据功能【测试可用】
Jun 13 Python
Python+OpenCV目标跟踪实现基本的运动检测
Jul 10 Python
详解django中url路由配置及渲染方式
Feb 25 Python
详解将Pandas中的DataFrame类型转换成Numpy中array类型的三种方法
Jul 06 Python
django 控制页面跳转的例子
Aug 06 Python
Python imread、newaxis用法详解
Nov 04 Python
python机器学习实现决策树
Nov 11 Python
解决Pytorch 加载训练好的模型 遇到的error问题
Jan 10 Python
在pycharm中使用matplotlib.pyplot 绘图时报错的解决
Jun 01 Python
python raise的基本使用
Sep 10 Python
Python虚拟环境virtualenv是如何使用的
Jun 20 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
You might like
PHP函数实现分页含文本分页和数字分页
2014/10/23 PHP
浅谈PHP中单引号和双引号到底有啥区别呢?
2015/03/04 PHP
ThinkPHP框架实现数据增删改
2017/05/07 PHP
Ubuntu 16.04中Laravel5.4升级到5.6的步骤
2018/12/07 PHP
兼容多浏览器的字幕特效Marquee的通用js类
2008/07/20 Javascript
用JavaScript实现单继承和多继承的简单方法
2009/03/29 Javascript
JavaScript修改css样式style动态改变元素样式
2013/12/16 Javascript
jquery操作复选框(checkbox)的12个小技巧总结
2014/02/04 Javascript
原生JS实现响应式瀑布流布局
2015/04/02 Javascript
JavaScript实现的经典文件树菜单效果
2015/09/08 Javascript
jQuery实现默认是闭合的FAQ展开效果菜单
2015/09/14 Javascript
JS实现判断图片是否加载完成的方法分析
2018/07/31 Javascript
[02:06]2018完美世界全国高校联赛秋季赛开始报名(附彩蛋)
2018/09/03 DOTA
[01:19:46]EG vs Secret 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.21.mp4
2020/07/19 DOTA
[01:06:54]DOTA2-DPC中国联赛 正赛 RNG vs Dragon BO3 第一场 1月24日
2021/03/11 DOTA
python使用urllib模块开发的多线程豆瓣小站mp3下载器
2014/01/16 Python
Python操作MySQL模拟银行转账
2018/03/12 Python
使用Python实现从各个子文件夹中复制指定文件的方法
2018/10/25 Python
python使用 request 发送表单数据操作示例
2019/09/25 Python
pytorch实现onehot编码转为普通label标签
2020/01/02 Python
Python2与Python3的区别详解
2020/02/09 Python
python接口自动化之ConfigParser配置文件的使用详解
2020/08/03 Python
如何利用CSS3制作3D效果文字具体实现样式
2013/05/02 HTML / CSS
西海岸男士和男童服装:Johnnie-O
2018/03/15 全球购物
英国现代、当代和设计师家具店:Furntastic
2020/07/18 全球购物
在数据文件自动增长时,自动增长是否会阻塞对文件的更新
2014/05/01 面试题
管理学专业个人求职信范文
2013/09/21 职场文书
慈善晚会策划方案
2014/05/14 职场文书
公司年终奖分配方案
2014/06/16 职场文书
小区的门卫岗位职责
2014/10/01 职场文书
淮阳太昊陵导游词
2015/02/10 职场文书
员工离职通知函
2015/04/25 职场文书
军训通讯稿范文
2015/07/18 职场文书
2015年卫生局工作总结
2015/07/24 职场文书
Nginx代理同域名前后端分离项目的完整步骤
2021/03/31 Servers
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
2021/05/24 Python