使用Keras构造简单的CNN网络实例


Posted in Python onJune 29, 2020

1. 导入各种模块

基本形式为:

import 模块名

from 某个文件 import 某个模块

2. 导入数据(以两类分类问题为例,即numClass = 2)

训练集数据data

可以看到,data是一个四维的ndarray

训练集的标签

3. 将导入的数据转化我keras可以接受的数据格式

keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数

label = np_utils.to_categorical(label, numClass

此时的label变为了如下形式

(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)

4. 建立CNN模型

以下图所示的CNN网络为例

#生成一个model
model = Sequential()
 
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
 
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
 
# layer5-fully connect
model.add(Dense(numClass, init='normal')) 
model.add(Activation('softmax'))

# 
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

5. 开始训练model

利用model.train_on_batch或者model.fit

补充知识:keras 多分类一些函数参数设置

用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28 
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000 
nb_validation_samples = 10000
epochs = 50 
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':
 input_shape = (3, img_width, img_height)
else:
 input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义

比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:

x = Dense(10,activation='softmax')(x)

此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)


model.compile(loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集

注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(
 rescale=1. / 255,
 shear_range=0.2,
 zoom_range=0.2,
 horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
 train_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical', #多分类问题设为'categorical'
 classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
 )

validation_generator = test_datagen.flow_from_directory(
 validation_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical'
 )

训练和保存模型及权值

model.fit_generator(
  train_generator,
  samples_per_epoch=nb_train_samples,
  nb_epoch=epochs,
  validation_data=validation_generator,
  nb_val_samples=nb_validation_samples
  )

model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束

图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。

此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
 img_addr=input('Please input your image address:')
 if img_addr=="exit":
  break
 else:
  img = load_img(img_addr, False, target_size=(28, 28))
  x = img_to_array(img) / 255.0
  x = np.expand_dims(x, axis=0)
  result = model.predict(x)
  ind=np.argmax(result,1)
  print('this is a ', classes[ind])

以上这篇使用Keras构造简单的CNN网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
将Python中的数据存储到系统本地的简单方法
Apr 11 Python
关于Python中Inf与Nan的判断问题详解
Feb 08 Python
python中通过预先编译正则表达式提高效率
Sep 25 Python
Python机器学习logistic回归代码解析
Jan 17 Python
Python 反转字符串(reverse)的方法小结
Feb 20 Python
python 利用栈和队列模拟递归的过程
May 29 Python
解决python升级引起的pip执行错误的问题
Jun 12 Python
使用Python AIML搭建聊天机器人的方法示例
Jul 09 Python
对python中的控制条件、循环和跳出详解
Jun 24 Python
Python实现的ftp服务器功能详解【附源码下载】
Jun 26 Python
python+selenium+chromedriver实现爬虫示例代码
Apr 10 Python
Python生成pdf目录书签的实例方法
Oct 29 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 #Python
Python enumerate() 函数如何实现索引功能
Jun 29 #Python
解决Keras中CNN输入维度报错问题
Jun 29 #Python
Python字符串split及rsplit方法原理详解
Jun 29 #Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 #Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 #Python
使用keras时input_shape的维度表示问题说明
Jun 29 #Python
You might like
php pack与unpack 摸板字符字符含义
2009/10/29 PHP
php批量缩放图片的代码[ini参数控制]
2011/02/11 PHP
简单的PHP缓存设计实现代码
2011/09/30 PHP
解析PHP处理换行符的问题 \r\n
2013/06/13 PHP
php打开远程文件的方法和风险及解决方法
2013/11/12 PHP
linux下安装php的memcached客户端
2014/08/03 PHP
php实现文章置顶功能的方法
2016/10/20 PHP
PHP attributes()函数讲解
2019/02/03 PHP
laravel实现按月或天或小时统计mysql数据的方法
2019/10/09 PHP
PHP实现计算器小功能
2020/08/28 PHP
JQuery 动画卷页 返回顶部 动画特效(兼容Chrome)
2010/02/15 Javascript
js实现动态添加、删除行、onkeyup表格求和示例
2013/08/18 Javascript
jquery 追加tr和删除tr示例代码
2013/09/12 Javascript
深入理解JQuery中的事件与动画
2016/05/18 Javascript
jQuery3.0中的buildFragment私有函数详解
2016/08/16 Javascript
JavaScript数组去重由慢到快由繁到简(优化篇)
2016/08/26 Javascript
JS正则表达式判断有效数实例代码
2017/03/13 Javascript
JS实现字符串翻转的方法分析
2018/08/31 Javascript
vue中使用codemirror的实例详解
2018/11/01 Javascript
vue结合el-upload实现腾讯云视频上传功能
2020/07/01 Javascript
JS数据类型判断的几种常用方法
2020/07/07 Javascript
openlayers实现图标拖动获取坐标
2020/09/25 Javascript
[01:34]传奇从这开始 2016国际邀请赛中国区预选赛震撼开启
2016/06/26 DOTA
python实现ipsec开权限实例
2014/11/11 Python
python结合shell查询google关键词排名的实现代码
2016/02/27 Python
对python3 Serial 串口助手的接收读取数据方法详解
2019/06/12 Python
django基于restframework的CBV封装详解
2019/08/08 Python
pygame实现贪吃蛇游戏(下)
2019/10/29 Python
python如何编写类似nmap的扫描工具
2020/11/06 Python
梅西百货澳大利亚:Macy’s Australia
2017/07/26 全球购物
Marc O’Polo俄罗斯官方在线商店:德国高端时尚品牌
2019/12/26 全球购物
大学生最常用的自我评价
2013/12/07 职场文书
房地产广告词大全
2014/03/19 职场文书
上课随便讲话检讨书
2014/09/12 职场文书
2014物价局群众路线对照检查材料思想汇报
2014/09/21 职场文书
新郎结婚感言
2015/07/31 职场文书