Keras使用ImageNet上预训练的模型方式


Posted in Python onMay 23, 2020

我就废话不多说了,大家还是直接看代码吧!

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
#Load the VGG model
vgg_model = vgg16.VGG16(weights='imagenet')
 
#Load the Inception_V3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
 
#Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
 
#Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

Keras使用ImageNet上预训练的模型方式

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

Keras使用ImageNet上预训练的模型方式

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

Keras使用ImageNet上预训练的模型方式

完整代码

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据
 
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
 
# input dimensions
img_rows, img_cols = 28,28
 
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
 
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
 
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
 
 #Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
 
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
 
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
 
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
 
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
 
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
 
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
 
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
 
# 基于Model方法构建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 编译模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
       loss= keras.losses.categorical_crossentropy,
       metrics= ['accuracy'])
# 训练配置,仅供参考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用 Selenium 实现网页截图实例
Jul 18 Python
详解Python中的文本处理
Apr 11 Python
Python中Django框架下的staticfiles使用简介
May 30 Python
Python脚本实现虾米网签到功能
Apr 12 Python
利用python写个下载teahour音频的小脚本
May 08 Python
Python3实现爬取指定百度贴吧页面并保存页面数据生成本地文档的方法
Apr 22 Python
python程序控制NAO机器人行走
Apr 29 Python
使用selenium模拟登录解决滑块验证问题的实现
May 10 Python
python线程安全及多进程多线程实现方法详解
Sep 27 Python
Python面向对象程序设计之继承、多态原理与用法详解
Mar 23 Python
Python之Matplotlib文字与注释的使用方法
Jun 18 Python
python全栈开发语法总结
Nov 22 Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
Keras实现将两个模型连接到一起
May 23 #Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
You might like
让你的WINDOWS同时支持MYSQL4,MYSQL4.1,MYSQL5X
2006/12/06 PHP
PHP fgetcsv 定义和用法(附windows与linux下兼容问题)
2012/05/29 PHP
浅析虚拟主机服务器php fsockopen函数被禁用的解决办法
2013/08/07 PHP
PHP自动重命名文件实现方法
2014/11/04 PHP
PHP中call_user_func_array回调函数的用法示例
2016/11/26 PHP
PHP写的简单数字验证码实例
2017/05/23 PHP
PHP生成腾讯云COS接口需要的请求签名
2018/05/20 PHP
网站被恶意镜像怎么办 php一段代码轻松搞定(全面版)
2018/10/23 PHP
asp.net和asp下ACCESS的参数化查询
2008/06/11 Javascript
Div Select挡住的解决办法
2008/08/07 Javascript
解决jquery .ajax 在IE下卡死问题的解决方法
2009/10/26 Javascript
关于html+ashx开发中几个问题的解决方法
2011/07/18 Javascript
JQuery 控制内容长度超出规定长度显示省略号
2014/05/23 Javascript
使用jquery实现的一个图片延迟加载插件(含图片延迟加载原理)
2014/06/05 Javascript
JQuery中DOM事件绑定用法详解
2015/06/13 Javascript
基于JavaScript判断浏览器到底是关闭还是刷新(超准确)
2016/02/01 Javascript
HTML5+Canvas调用手机拍照功能实现图片上传(上)
2017/04/21 Javascript
原生js+cookie实现购物车功能的方法分析
2017/12/21 Javascript
vue中锚点的三种方法
2018/07/06 Javascript
vue实现分页组件
2020/06/16 Javascript
开发Node CLI构建微信小程序脚手架的示例
2020/03/27 Javascript
[00:31]2016完美“圣”典风云人物:国士无双宣传片
2016/12/04 DOTA
使用python爬虫获取黄金价格的核心代码
2018/06/13 Python
Python中如何使用if语句处理列表实例代码
2019/02/24 Python
python数据类型可变不可变知识点总结
2020/03/06 Python
Keras:Unet网络实现多类语义分割方式
2020/06/11 Python
Django基于Models定制Admin后台实现过程解析
2020/11/11 Python
python实现不同数据库间数据同步功能
2021/02/25 Python
电子狗项圈:eDog Australia
2019/12/04 全球购物
牵手50新加坡:专为黄金岁月的单身人士而设的交友网站
2020/08/16 全球购物
.NET面试问题集
2015/12/08 面试题
教师党员公开承诺书
2014/03/25 职场文书
新年主持词
2014/03/27 职场文书
软件售后服务承诺书
2014/05/21 职场文书
党员批评与自我批评发言材料
2014/10/14 职场文书
红灯733-1型14管5波段半导体收音机
2021/04/22 无线电