使用Keras构造简单的CNN网络实例


Posted in Python onJune 29, 2020

1. 导入各种模块

基本形式为:

import 模块名

from 某个文件 import 某个模块

2. 导入数据(以两类分类问题为例,即numClass = 2)

训练集数据data

可以看到,data是一个四维的ndarray

训练集的标签

3. 将导入的数据转化我keras可以接受的数据格式

keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数

label = np_utils.to_categorical(label, numClass

此时的label变为了如下形式

(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)

4. 建立CNN模型

以下图所示的CNN网络为例

#生成一个model
model = Sequential()
 
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
 
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
 
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
 
# layer5-fully connect
model.add(Dense(numClass, init='normal')) 
model.add(Activation('softmax'))

# 
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")

5. 开始训练model

利用model.train_on_batch或者model.fit

补充知识:keras 多分类一些函数参数设置

用Lenet-5 识别Mnist数据集为例子:

采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。

首先import所需要的模块

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K

定义图像数据信息及训练参数

img_width, img_height = 28, 28 
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000 
nb_validation_samples = 10000
epochs = 50 
batch_size = 32

判断使用的后台

if K.image_dim_ordering() == 'th':
 input_shape = (3, img_width, img_height)
else:
 input_shape = (img_width, img_height, 3)

网络模型定义

主要注意最后的输出层定义

比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:

x = Dense(10,activation='softmax')(x)

此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid

img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)

x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)


model.compile(loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['accuracy'])
model.summary()

利用ImageDataGenerator传入图像数据集

注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。

train_datagen = ImageDataGenerator(
 rescale=1. / 255,
 shear_range=0.2,
 zoom_range=0.2,
 horizontal_flip=True)

# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(
 train_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical', #多分类问题设为'categorical'
 classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
 )

validation_generator = test_datagen.flow_from_directory(
 validation_data_dir,
 target_size=(img_width, img_height),
 batch_size=batch_size,
 class_mode='categorical'
 )

训练和保存模型及权值

model.fit_generator(
  train_generator,
  samples_per_epoch=nb_train_samples,
  nb_epoch=epochs,
  validation_data=validation_generator,
  nb_val_samples=nb_validation_samples
  )

model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')

至此训练结束

图片预测

注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。

此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。

from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
 img_addr=input('Please input your image address:')
 if img_addr=="exit":
  break
 else:
  img = load_img(img_addr, False, target_size=(28, 28))
  x = img_to_array(img) / 255.0
  x = np.expand_dims(x, axis=0)
  result = model.predict(x)
  ind=np.argmax(result,1)
  print('this is a ', classes[ind])

以上这篇使用Keras构造简单的CNN网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python时区设置方法与pytz查询时区教程
Nov 27 Python
介绍Python中的fabs()方法的使用
May 14 Python
Python实现自动添加脚本头信息的示例代码
Sep 02 Python
疯狂上涨的Python 开发者应从2.x还是3.x着手?
Nov 16 Python
django 2.0更新的10条注意事项总结
Jan 05 Python
怎么使用pipenv管理你的python项目
Mar 12 Python
关于django 数据库迁移(migrate)应该知道的一些事
May 27 Python
django与小程序实现登录验证功能的示例代码
Feb 19 Python
Tensorflow中的降维函数tf.reduce_*使用总结
Apr 20 Python
Python3爬虫中pyspider的安装步骤
Jul 29 Python
python异步的ASGI与Fast Api实现
Jul 16 Python
python垃圾回收机制原理分析
Apr 13 Python
基于K.image_data_format() == 'channels_first' 的理解
Jun 29 #Python
Python enumerate() 函数如何实现索引功能
Jun 29 #Python
解决Keras中CNN输入维度报错问题
Jun 29 #Python
Python字符串split及rsplit方法原理详解
Jun 29 #Python
浅谈Keras参数 input_shape、input_dim和input_length用法
Jun 29 #Python
使用 prometheus python 库编写自定义指标的方法(完整代码)
Jun 29 #Python
使用keras时input_shape的维度表示问题说明
Jun 29 #Python
You might like
php实现设计模式中的单例模式详解
2014/10/11 PHP
YII Framework框架教程之使用YIIC快速创建YII应用详解
2016/03/15 PHP
Laravel路由设定和子路由设定实例分析
2016/03/30 PHP
修改jQuery.Autocomplete插件 支持中文输入法 避免TAB、ENTER键失效、导致表单提交
2009/10/11 Javascript
Javascript Math ceil()、floor()、round()三个函数的区别
2010/03/09 Javascript
父节点获取子节点的字符串示例代码
2014/02/26 Javascript
js实现的点击div区域外隐藏div区域
2014/06/30 Javascript
JavaScript获取图片真实大小代码实例
2014/09/24 Javascript
nodejs开发环境配置与使用
2014/11/17 NodeJs
JavaScritp添加url参数并将参数加入到url中及更改url参数的方法
2015/10/26 Javascript
Active控件问题小结(附解决办法)
2016/06/09 Javascript
浅析script标签中的defer与async属性
2016/11/30 Javascript
JS闭包与延迟求值用法示例
2016/12/22 Javascript
js+html5实现页面可刷新的倒计时效果
2017/07/15 Javascript
JavaScript实现指定数量的并发限制的示例代码
2020/03/10 Javascript
vue-socket.io接收不到数据问题的解决方法
2020/05/13 Javascript
vue全局使用axios的操作
2020/09/08 Javascript
vue使用Sass时报错问题的解决方法
2020/10/14 Javascript
python连接MySQL、MongoDB、Redis、memcache等数据库的方法
2013/11/15 Python
在Python的web框架中中编写日志列表的教程
2015/04/30 Python
Python3网络爬虫开发实战之极验滑动验证码的识别
2019/08/02 Python
python numpy之np.random的随机数函数使用介绍
2019/10/06 Python
Python3 元组tuple入门基础
2020/02/09 Python
python中pyqtgraph知识点总结
2021/01/26 Python
全面解析CSS Media媒体查询使用操作(推荐)
2017/08/15 HTML / CSS
深入浅析css3 中display box使用方法
2015/11/25 HTML / CSS
HTML5学习心得总结(推荐)
2016/07/08 HTML / CSS
美国南部最大的家族百货公司:Belk
2017/01/30 全球购物
THE OUTNET美国官网:国际设计师品牌折扣网站
2017/03/07 全球购物
党校培训思想汇报
2014/01/03 职场文书
在职证明书范本(2014新版)
2014/09/25 职场文书
2014年宣传工作总结
2014/11/18 职场文书
二年级学生期末评语
2014/12/26 职场文书
军训结束新闻稿
2015/07/17 职场文书
《领导干部从政道德启示录》学习心得体会
2016/01/20 职场文书
Pandas数据类型之category的用法
2021/06/28 Python