使用Keras构造简单的CNN网络实例


Posted in Python onJune 29, 2020

1. 导入各种模块

基本形式为:

import 模块名

from 某个文件 import 某个模块

2. 导入数据(以两类分类问题为例,即numClass = 2)

训练集数据data

可以看到,data是一个四维的ndarray

训练集的标签

3. 将导入的数据转化我keras可以接受的数据格式

keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数

label = np_utils.to_categorical(label, numClass

此时的label变为了如下形式

(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)

4. 建立CNN模型

以下图所示的CNN网络为例

#生成一个model
model = Sequential()
 
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
 
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
 
# layer5-fully connect
model.add(Dense(numClass, init='normal')) 
model.add(Activation('softmax'))

# 
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

5. 开始训练model

利用model.train_on_batch或者model.fit

补充知识:keras 多分类一些函数参数设置

用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28 
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000 
nb_validation_samples = 10000
epochs = 50 
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':
 input_shape = (3, img_width, img_height)
else:
 input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义

比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:

x = Dense(10,activation='softmax')(x)

此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)


model.compile(loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集

注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(
 rescale=1. / 255,
 shear_range=0.2,
 zoom_range=0.2,
 horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
 train_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical', #多分类问题设为'categorical'
 classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
 )

validation_generator = test_datagen.flow_from_directory(
 validation_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical'
 )

训练和保存模型及权值

model.fit_generator(
  train_generator,
  samples_per_epoch=nb_train_samples,
  nb_epoch=epochs,
  validation_data=validation_generator,
  nb_val_samples=nb_validation_samples
  )

model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束

图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。

此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
 img_addr=input('Please input your image address:')
 if img_addr=="exit":
  break
 else:
  img = load_img(img_addr, False, target_size=(28, 28))
  x = img_to_array(img) / 255.0
  x = np.expand_dims(x, axis=0)
  result = model.predict(x)
  ind=np.argmax(result,1)
  print('this is a ', classes[ind])

以上这篇使用Keras构造简单的CNN网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python抓取Discuz!用户名脚本代码
Dec 30 Python
Python多线程threading和multiprocessing模块实例解析
Jan 29 Python
Django中cookie的基本使用方法示例
Feb 03 Python
TensorFlow saver指定变量的存取
Mar 10 Python
django+mysql的使用示例
Nov 23 Python
python 一个figure上显示多个图像的实例
Jul 08 Python
Python一键查找iOS项目中未使用的图片、音频、视频资源
Aug 12 Python
python 生成器和迭代器的原理解析
Oct 12 Python
在Python中使用K-Means聚类和PCA主成分分析进行图像压缩
Apr 10 Python
Python3以GitHub为例来实现模拟登录和爬取的实例讲解
Jul 30 Python
python RSA加密的示例
Dec 09 Python
Python基础之教你怎么在M1系统上使用pandas
May 08 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 #Python
Python enumerate() 函数如何实现索引功能
Jun 29 #Python
解决Keras中CNN输入维度报错问题
Jun 29 #Python
Python字符串split及rsplit方法原理详解
Jun 29 #Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 #Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 #Python
使用keras时input_shape的维度表示问题说明
Jun 29 #Python
You might like
php使用cookie实现记住登录状态
2015/04/27 PHP
Yii2配置Nginx伪静态的方法
2017/05/05 PHP
PHP实现打包zip并下载功能
2018/06/12 PHP
php 命名空间(namespace)原理与用法实例小结
2019/11/13 PHP
Iframe自适应高度绝对好使的代码 兼容IE,遨游,火狐
2011/01/27 Javascript
js 实现图片预加载(js操作 Image对象属性complete ,事件onload 异步加载图片)
2011/03/25 Javascript
jQuery 在光标定位的地方插入文字的插件
2012/05/10 Javascript
Jquery选择子控件"大于号"和" "区别介绍及使用示例
2013/06/25 Javascript
Javascript自定义函数判断网站访问类型是PC还是移动终端
2014/01/10 Javascript
node.js中的fs.futimesSync方法使用说明
2014/12/17 Javascript
node-webkit打包成exe文件被360误报木马的解决方法
2015/03/11 Javascript
JavaScript在Android的WebView中parseInt函数转换不正确问题解决方法
2015/04/25 Javascript
AngularJS在IE8的不支持的解决方法
2016/05/13 Javascript
全面了解JavaScript的数据类型转换
2016/07/01 Javascript
浅谈移动端之js touch事件 手势滑动事件
2016/11/07 Javascript
Javarscript中模块(module)、加载(load)与捆绑(bundle)详解
2017/05/28 Javascript
JavaScript之浏览器对象_动力节点Java学院整理
2017/07/03 Javascript
基于JavaScript实现百度搜索框效果
2020/06/28 Javascript
jquery 获取索引值在一定范围的列表方法
2018/01/25 jQuery
Vue在页面右上角实现可悬浮/隐藏的系统菜单
2018/05/04 Javascript
element-ui中的select下拉列表设置默认值方法
2018/08/24 Javascript
JS监听滚动和id自动定位滚动
2018/12/18 Javascript
浅谈JavaScript面向对象--继承
2019/03/20 Javascript
使用jQuery如何写一个含验证码的登录界面
2019/05/13 jQuery
JS实现普通轮播图特效
2020/01/01 Javascript
js this 绑定机制深入详解
2020/04/30 Javascript
python并发编程 Process对象的其他属性方法join方法详解
2019/08/20 Python
python dataframe NaN处理方式
2019/12/26 Python
python使用OpenCV模块实现图像的融合示例代码
2020/04/10 Python
Pycharm plot独立窗口显示的操作
2020/12/11 Python
美国玛丽莎收藏奢华时尚商店:Marissa Collections
2016/11/21 全球购物
泰国演唱会订票网站:StubHub泰国
2018/02/26 全球购物
英国大码女性时装零售商:Evans
2018/08/29 全球购物
解除劳动合同协议书范本
2014/09/13 职场文书
上班迟到检讨书范文300字
2014/11/02 职场文书
关于感恩的作文
2019/08/26 职场文书