Keras使用ImageNet上预训练的模型方式


Posted in Python onMay 23, 2020

我就废话不多说了,大家还是直接看代码吧!

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
#Load the VGG model
vgg_model = vgg16.VGG16(weights='imagenet')
 
#Load the Inception_V3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
 
#Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
 
#Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

Keras使用ImageNet上预训练的模型方式

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

Keras使用ImageNet上预训练的模型方式

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

Keras使用ImageNet上预训练的模型方式

完整代码

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据
 
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
 
# input dimensions
img_rows, img_cols = 28,28
 
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
 
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
 
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
 
 #Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
 
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
 
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
 
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
 
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
 
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
 
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
 
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
 
# 基于Model方法构建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 编译模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
       loss= keras.losses.categorical_crossentropy,
       metrics= ['accuracy'])
# 训练配置,仅供参考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python格式化字符串实例总结
Sep 28 Python
详细解析Python中的变量的数据类型
May 13 Python
python爬虫headers设置后无效的解决方法
Oct 21 Python
CentOS6.9 Python环境配置(python2.7、pip、virtualenv)
May 06 Python
Pandas0.25来了千万别错过这10大好用的新功能
Aug 07 Python
pycharm无法导入本地模块的解决方式
Feb 12 Python
用python介绍4种常用的单链表翻转的方法小结
Feb 24 Python
Python 调用有道翻译接口实现翻译
Mar 02 Python
Django封装交互接口代码
Jul 12 Python
Python自动登录QQ的实现示例
Aug 28 Python
Python命令行参数定义及需要注意的地方
Nov 30 Python
Python爬虫基础之简单说一下scrapy的框架结构
Jun 26 Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
Keras实现将两个模型连接到一起
May 23 #Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
You might like
PHP实现搜索地理位置及计算两点地理位置间距离的实例
2016/01/08 PHP
PHP采用超长(超大)数字运算防止数字以科学计数法显示的方法
2016/04/01 PHP
php 多文件上传的实现实例
2016/10/23 PHP
JS处理VBArray的函数使用说明
2008/05/11 Javascript
js原型链原理看图说明
2012/07/07 Javascript
根据选择不同的下拉值出现相对应的文本输入框
2013/08/01 Javascript
jQuery Validate 验证,校验规则写在控件中的具体实例
2014/02/27 Javascript
使用JavaScript脚本无法直接改变Asp.net中Checkbox控件的Enable属性的解决方法
2015/09/16 Javascript
js实现搜索框关键字智能匹配代码
2020/03/26 Javascript
JavaScipt选取文档元素的方法(推荐)
2016/08/05 Javascript
浅谈jQuery中的eq()与DOM中element.[]的区别
2016/10/28 Javascript
JS实现点击网页判断是否安装app并打开否则跳转app store
2016/11/18 Javascript
JavaScript中三种常见的排序方法
2017/02/24 Javascript
CodeMirror js代码加亮使用总结
2017/03/25 Javascript
Windows下使用Nodejs运行js的方法
2017/09/02 NodeJs
vue二级路由设置方法
2018/02/09 Javascript
Python构建XML树结构的方法示例
2017/06/30 Python
Python实现调度算法代码详解
2017/12/01 Python
python实现决策树分类算法
2017/12/21 Python
解决python nohup linux 后台运行输出的问题
2018/05/11 Python
python中tkinter的应用:修改字体的实例讲解
2019/07/17 Python
python实现桌面托盘气泡提示
2019/07/29 Python
matplotlib之属性组合包(cycler)的使用
2021/02/24 Python
纽约现代艺术博物馆商店:MoMA STORE(室内家具和杂货商品)
2016/08/02 全球购物
加拿大便宜的隐形眼镜商店:Clearly
2016/09/15 全球购物
南威尔士家居商店:Leekes
2016/10/25 全球购物
ProBikeKit新西兰:自行车套件,跑步和铁人三项装备
2017/04/05 全球购物
梅西酒窖:Macy’s Wine Cellar
2018/01/07 全球购物
英国在线潜水商店:Simply Scuba
2019/03/25 全球购物
学术会议欢迎词
2014/01/09 职场文书
关于打架的检讨书
2014/01/17 职场文书
庆八一活动方案
2014/01/25 职场文书
会计岗位职责范本
2014/03/07 职场文书
工程学毕业生自荐信
2014/06/14 职场文书
创业计划书之青年旅馆
2019/09/23 职场文书
mysql全面解析json/数组
2022/07/07 MySQL