使用Keras构造简单的CNN网络实例


Posted in Python onJune 29, 2020

1. 导入各种模块

基本形式为:

import 模块名

from 某个文件 import 某个模块

2. 导入数据(以两类分类问题为例,即numClass = 2)

训练集数据data

可以看到,data是一个四维的ndarray

训练集的标签

3. 将导入的数据转化我keras可以接受的数据格式

keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数

label = np_utils.to_categorical(label, numClass

此时的label变为了如下形式

(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)

4. 建立CNN模型

以下图所示的CNN网络为例

#生成一个model
model = Sequential()
 
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
 
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
 
# layer5-fully connect
model.add(Dense(numClass, init='normal')) 
model.add(Activation('softmax'))

# 
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

5. 开始训练model

利用model.train_on_batch或者model.fit

补充知识:keras 多分类一些函数参数设置

用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28 
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000 
nb_validation_samples = 10000
epochs = 50 
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':
 input_shape = (3, img_width, img_height)
else:
 input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义

比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:

x = Dense(10,activation='softmax')(x)

此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)


model.compile(loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集

注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(
 rescale=1. / 255,
 shear_range=0.2,
 zoom_range=0.2,
 horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
 train_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical', #多分类问题设为'categorical'
 classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
 )

validation_generator = test_datagen.flow_from_directory(
 validation_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical'
 )

训练和保存模型及权值

model.fit_generator(
  train_generator,
  samples_per_epoch=nb_train_samples,
  nb_epoch=epochs,
  validation_data=validation_generator,
  nb_val_samples=nb_validation_samples
  )

model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束

图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。

此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
 img_addr=input('Please input your image address:')
 if img_addr=="exit":
  break
 else:
  img = load_img(img_addr, False, target_size=(28, 28))
  x = img_to_array(img) / 255.0
  x = np.expand_dims(x, axis=0)
  result = model.predict(x)
  ind=np.argmax(result,1)
  print('this is a ', classes[ind])

以上这篇使用Keras构造简单的CNN网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python自动化运维之IP地址处理模块详解
Dec 10 Python
python中数据爬虫requests库使用方法详解
Feb 11 Python
python中abs&map&reduce简介
Feb 20 Python
Python DataFrame设置/更改列表字段/元素类型的方法
Jun 09 Python
Python实现基于KNN算法的笔迹识别功能详解
Jul 09 Python
Python实现的多叉树寻找最短路径算法示例
Jul 30 Python
浅谈python str.format与制表符\t关于中文对齐的细节问题
Jan 14 Python
python re模块匹配贪婪和非贪婪模式详解
Feb 11 Python
Python在终端通过pip安装好包以后在Pycharm中依然无法使用的问题(三种解决方案)
Mar 10 Python
Python网络爬虫四大选择器用法原理总结
Jun 01 Python
PyQT5 实现快捷键复制表格数据的方法示例
Jun 19 Python
pyx文件 生成pyd 文件用于 cython调用的实现
Mar 04 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 #Python
Python enumerate() 函数如何实现索引功能
Jun 29 #Python
解决Keras中CNN输入维度报错问题
Jun 29 #Python
Python字符串split及rsplit方法原理详解
Jun 29 #Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 #Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 #Python
使用keras时input_shape的维度表示问题说明
Jun 29 #Python
You might like
PHP面向对象程序设计组合模式与装饰模式详解
2016/12/02 PHP
php根据用户名和手机号查询是否存在手机号码
2017/02/16 PHP
asp 取文本框名称代码
2008/12/02 Javascript
深入解析contentWindow, contentDocument
2013/07/04 Javascript
js调用css属性写法
2013/09/21 Javascript
jtable列中自定义button示例代码
2013/11/21 Javascript
JavaScript中字面量与函数的基本使用知识
2015/10/20 Javascript
浅谈Angular2 ng-content 指令在组件中嵌入内容
2017/08/18 Javascript
详解基于Vue2.0实现的移动端弹窗(Alert, Confirm, Toast)组件
2018/08/02 Javascript
javascript中如何判断类型汇总
2019/05/14 Javascript
vue-router的两种模式的区别
2019/05/30 Javascript
vue项目在线上服务器访问失败原因分析
2020/08/14 Javascript
[01:03:18]DOTA2-DPC中国联赛 正赛 RNG vs Dynasty BO3 第一场 1月29日
2021/03/11 DOTA
Python入门篇之字符串
2014/10/17 Python
Python中property属性实例解析
2018/02/10 Python
python如何为创建大量实例节省内存
2018/03/20 Python
Python之批量创建文件的实例讲解
2018/05/10 Python
Python日志无延迟实时写入的示例
2019/07/11 Python
python scrapy爬虫代码及填坑
2019/08/12 Python
Python 基于wxpy库实现微信添加好友功能(简洁)
2019/11/29 Python
解决pip安装的第三方包在PyCharm无法导入的问题
2020/10/15 Python
Orlebar Brown官网:设计师泳裤和泳装
2020/12/08 全球购物
简单的JAVA编程面试题
2013/03/19 面试题
竞选演讲稿范文
2013/12/28 职场文书
项目管理计划书
2014/01/09 职场文书
推荐信怎么写
2014/05/09 职场文书
卖车协议书范例
2014/09/16 职场文书
公务员四风问题对照检查材料整改措施
2014/09/26 职场文书
优秀护士事迹材料
2014/12/25 职场文书
李白故里导游词
2015/02/12 职场文书
2015年医德医风工作总结
2015/04/02 职场文书
交通处罚决定书
2015/06/24 职场文书
CSS中em的正确打开方式详解
2021/04/08 HTML / CSS
详解Redis实现限流的三种方式
2021/04/27 Redis
详解Java实践之建造者模式
2021/06/18 Java/Android
Python字符串常规操作小结
2022/04/03 Python