使用Keras构造简单的CNN网络实例


Posted in Python onJune 29, 2020

1. 导入各种模块

基本形式为:

import 模块名

from 某个文件 import 某个模块

2. 导入数据(以两类分类问题为例,即numClass = 2)

训练集数据data

可以看到,data是一个四维的ndarray

训练集的标签

3. 将导入的数据转化我keras可以接受的数据格式

keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数

label = np_utils.to_categorical(label, numClass

此时的label变为了如下形式

(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)

4. 建立CNN模型

以下图所示的CNN网络为例

#生成一个model
model = Sequential()
 
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
 
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
 
# layer5-fully connect
model.add(Dense(numClass, init='normal')) 
model.add(Activation('softmax'))

# 
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

5. 开始训练model

利用model.train_on_batch或者model.fit

补充知识:keras 多分类一些函数参数设置

用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28 
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000 
nb_validation_samples = 10000
epochs = 50 
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':
 input_shape = (3, img_width, img_height)
else:
 input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义

比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:

x = Dense(10,activation='softmax')(x)

此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)


model.compile(loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集

注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(
 rescale=1. / 255,
 shear_range=0.2,
 zoom_range=0.2,
 horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
 train_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical', #多分类问题设为'categorical'
 classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
 )

validation_generator = test_datagen.flow_from_directory(
 validation_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical'
 )

训练和保存模型及权值

model.fit_generator(
  train_generator,
  samples_per_epoch=nb_train_samples,
  nb_epoch=epochs,
  validation_data=validation_generator,
  nb_val_samples=nb_validation_samples
  )

model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束

图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。

此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
 img_addr=input('Please input your image address:')
 if img_addr=="exit":
  break
 else:
  img = load_img(img_addr, False, target_size=(28, 28))
  x = img_to_array(img) / 255.0
  x = np.expand_dims(x, axis=0)
  result = model.predict(x)
  ind=np.argmax(result,1)
  print('this is a ', classes[ind])

以上这篇使用Keras构造简单的CNN网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 出现错误TypeError: ‘NoneType’ object is not iterable解决办法
Jan 12 Python
Python实现pdf文档转txt的方法示例
Jan 19 Python
Python(Django)项目与Apache的管理交互的方法
May 16 Python
python如何生成各种随机分布图
Aug 27 Python
Python Pexpect库的简单使用方法
Jan 29 Python
Python 利用切片从列表中取出一部分使用的方法
Feb 01 Python
Python实现带下标索引的遍历操作示例
May 30 Python
Spring实战之使用util:命名空间简化配置操作示例
Dec 09 Python
从pandas一个单元格的字符串中提取字符串方式
Dec 17 Python
修改Pandas的行或列的名字(重命名)
Dec 18 Python
python给指定csv表格中的联系人群发邮件(带附件的邮件)
Dec 31 Python
Pytorch 实现冻结指定卷积层的参数
Jan 06 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 #Python
Python enumerate() 函数如何实现索引功能
Jun 29 #Python
解决Keras中CNN输入维度报错问题
Jun 29 #Python
Python字符串split及rsplit方法原理详解
Jun 29 #Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 #Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 #Python
使用keras时input_shape的维度表示问题说明
Jun 29 #Python
You might like
PHP日期时间函数的高级应用技巧
2009/05/16 PHP
PHP 处理图片的类实现代码
2009/10/23 PHP
thinkphp的c方法使用示例
2014/02/24 PHP
ThinkPHP使用心得分享-ThinkPHP + Ajax 实现2级联动下拉菜单
2014/05/15 PHP
浅谈php中urlencode与rawurlencode的区别
2016/09/05 PHP
Laravel框架实现简单的学生信息管理平台案例
2019/05/07 PHP
JS+CSS实现自动改变切换方向图片幻灯切换效果的方法
2015/03/02 Javascript
jQuery网页选项卡插件rTabs用法实例分析
2015/08/26 Javascript
jQuery1.9.1源码分析系列(十六)ajax之ajax框架
2015/12/04 Javascript
JavaScript中的操作符类型转换示例总结
2016/05/30 Javascript
如何通过非数字与字符的方式实现PHP WebShell详解
2017/07/02 Javascript
Vue 项目代理设置的优化
2018/04/17 Javascript
解决vue-cli项目webpack打包后iconfont文件路径的问题
2018/09/01 Javascript
vue 基于element-ui 分页组件封装的实例代码
2018/12/10 Javascript
详解vue-cli 2.0配置文件(小结)
2019/01/14 Javascript
JS实现滑动拼图验证功能完整示例
2020/03/29 Javascript
vue实现登录拦截
2020/06/29 Javascript
py中的目录与文件判别代码
2008/07/16 Python
python下载图片实现方法(超简单)
2017/07/21 Python
Python 互换字典的键值对实例
2019/02/12 Python
详解Python读取yaml文件多层菜单
2019/03/23 Python
windows环境中利用celery实现简单任务队列过程解析
2019/11/29 Python
使用pygame编写Flappy bird小游戏
2020/03/14 Python
Python如何输出百分比
2020/07/31 Python
香港莎莎官网Sasa.com:亚洲著名国际化妆品商城
2019/11/10 全球购物
广州御银科技股份有限公司试卷(C++)
2016/11/04 面试题
杭州时比特电子有限公司SQL
2013/08/22 面试题
《猴子种树》教学反思
2014/02/14 职场文书
暑期培训随笔感言
2014/03/10 职场文书
毕业晚会主持词
2014/03/24 职场文书
高等教育专业自荐信范文
2014/03/26 职场文书
美容院经理岗位职责
2014/04/03 职场文书
城管个人总结
2015/02/28 职场文书
校运会通讯稿
2015/07/18 职场文书
“学党章、守党纪、讲党规”学习心得体会
2016/01/14 职场文书
PHP控制循环操作的时间
2021/04/01 PHP