Keras使用ImageNet上预训练的模型方式


Posted in Python onMay 23, 2020

我就废话不多说了,大家还是直接看代码吧!

import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
 
#Load the VGG model
vgg_model = vgg16.VGG16(weights='imagenet')
 
#Load the Inception_V3 model
inception_model = inception_v3.InceptionV3(weights='imagenet')
 
#Load the ResNet50 model
resnet_model = resnet50.ResNet50(weights='imagenet')
 
#Load the MobileNet model
mobilenet_model = mobilenet.MobileNet(weights='imagenet')

在以上代码中,我们首先import各种模型对应的module,然后load模型,并用ImageNet的参数初始化模型的参数。

如果不想使用ImageNet上预训练到的权重初始话模型,可以将各语句的中'imagenet'替换为'None'。

补充知识:keras上使用alexnet模型来高准确度对mnist数据进行分类

纲要

本文有两个特点:一是直接对本地mnist数据进行读取(假设事先已经下载或从别处拷来)二是基于keras框架(网上多是基于tf)使用alexnet对mnist数据进行分类,并获得较高准确度(约为98%)

本地数据读取和分析

很多代码都是一开始简单调用一行代码来从网站上下载mnist数据,虽然只有10来MB,但是现在下载速度非常慢,而且经常中途出错,要费很大的劲才能拿到数据。

(X_train, y_train), (X_test, y_test) = mnist.load_data()

其实可以单独来获得这些数据(一共4个gz包,如下所示),然后调用别的接口来分析它们。

Keras使用ImageNet上预训练的模型方式

mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录

x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels

这里面要注意的是,两种接口拿到的数据形式是不一样的。 从网上直接下载下来的数据 其image data值的范围是0~255,且label值为0,1,2,3...9。 而第二种接口获取的数据 image值已经除以255(归一化)变成0~1范围,且label值已经是one-hot形式(one_hot=True时),比如label值2的one-hot code为(0 0 1 0 0 0 0 0 0 0)

所以,以第一种方式获取的数据需要做一些预处理(归一和one-hot)才能输入网络模型进行训练 而第二种接口拿到的数据则可以直接进行训练。

Alexnet模型的微调

按照公开的模型框架,Alexnet只有第1、2个卷积层才跟着BatchNormalization,后面三个CNN都没有(如有说错,请指正)。如果按照这个来搭建网络模型,很容易导致梯度消失,现象就是 accuracy值一直处在很低的值。 如下所示。

Keras使用ImageNet上预训练的模型方式

在每个卷积层后面都加上BN后,准确度才迭代提高。如下所示

Keras使用ImageNet上预训练的模型方式

完整代码

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import ModelCheckpoint
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #tensorflow已经包含了mnist案例的数据
 
batch_size = 64
num_classes = 10
epochs = 10
img_shape = (28,28,1)
 
# input dimensions
img_rows, img_cols = 28,28
 
# dataset input
#(x_train, y_train), (x_test, y_test) = mnist.load_data()
mnist = input_data.read_data_sets("./MNIST_data", one_hot = True) #导入已经下载好的数据集,"./MNIST_data"为存放mnist数据的目录
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
print(mnist.validation.images.shape, mnist.validation.labels.shape)
 
x_train = mnist.train.images
y_train = mnist.train.labels
x_test = mnist.test.images
y_test = mnist.test.labels
 
# data initialization
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
 
# Define the input layer
inputs = keras.Input(shape = [img_rows, img_cols, 1])
 
 #Define the converlutional layer 1
conv1 = keras.layers.Conv2D(filters= 64, kernel_size= [11, 11], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(inputs)
# Define the pooling layer 1
pooling1 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv1)
# Define the standardization layer 1
stand1 = keras.layers.BatchNormalization(axis= 1)(pooling1)
 
# Define the converlutional layer 2
conv2 = keras.layers.Conv2D(filters= 192, kernel_size= [5, 5], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand1)
# Defien the pooling layer 2
pooling2 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv2)
# Define the standardization layer 2
stand2 = keras.layers.BatchNormalization(axis= 1)(pooling2)
 
# Define the converlutional layer 3
conv3 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand2)
stand3 = keras.layers.BatchNormalization(axis=1)(conv3)
 
# Define the converlutional layer 4
conv4 = keras.layers.Conv2D(filters= 384, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand3)
stand4 = keras.layers.BatchNormalization(axis=1)(conv4)
 
# Define the converlutional layer 5
conv5 = keras.layers.Conv2D(filters= 256, kernel_size= [3, 3], strides= [1, 1], activation= keras.activations.relu, use_bias= True, padding= 'same')(stand4)
pooling5 = keras.layers.AveragePooling2D(pool_size= [2, 2], strides= [2, 2], padding= 'valid')(conv5)
stand5 = keras.layers.BatchNormalization(axis=1)(pooling5)
 
# Define the fully connected layer
flatten = keras.layers.Flatten()(stand5)
fc1 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(flatten)
drop1 = keras.layers.Dropout(0.5)(fc1)
 
fc2 = keras.layers.Dense(4096, activation= keras.activations.relu, use_bias= True)(drop1)
drop2 = keras.layers.Dropout(0.5)(fc2)
 
fc3 = keras.layers.Dense(10, activation= keras.activations.softmax, use_bias= True)(drop2)
 
# 基于Model方法构建模型
model = keras.Model(inputs= inputs, outputs = fc3)
# 编译模型
model.compile(optimizer= tf.train.AdamOptimizer(0.001),
       loss= keras.losses.categorical_crossentropy,
       metrics= ['accuracy'])
# 训练配置,仅供参考
model.fit(x_train, y_train, batch_size= batch_size, epochs= epochs, validation_data=(x_test,y_test))

以上这篇Keras使用ImageNet上预训练的模型方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用scapy模拟数据包实现arp攻击、dns放大攻击例子
Oct 23 Python
python 列表,数组,矩阵两两转换tolist()的实例
Apr 04 Python
Python图像处理之图像的读取、显示与保存操作【测试可用】
Jan 04 Python
Python正则表达式匹配日期与时间的方法
Jul 07 Python
python单例模式的多种实现方法
Jul 26 Python
浅谈pytorch grad_fn以及权重梯度不更新的问题
Aug 20 Python
Python 合并多个TXT文件并统计词频的实现
Aug 23 Python
python求质数列表的例子
Nov 24 Python
python本地文件服务器实例教程
May 02 Python
Python面向对象之成员相关知识总结
Jun 24 Python
Python办公自动化PPT批量转换操作
Sep 15 Python
教你使用Python获取QQ音乐某个歌手的歌单
Apr 03 Python
使用Keras预训练模型ResNet50进行图像分类方式
May 23 #Python
基于Python中random.sample()的替代方案
May 23 #Python
keras 自定义loss损失函数,sample在loss上的加权和metric详解
May 23 #Python
keras中模型训练class_weight,sample_weight区别说明
May 23 #Python
浅谈keras中的Merge层(实现层的相加、相减、相乘实例)
May 23 #Python
Keras实现将两个模型连接到一起
May 23 #Python
keras 获取某层输出 获取复用层的多次输出实例
May 23 #Python
You might like
PHP实现生成透明背景的PNG缩略图函数分享
2014/07/08 PHP
jQuery获取json后使用zy_tmpl生成下拉菜单
2015/03/27 PHP
使用 laravel sms 构建短信验证码发送校验功能
2017/11/06 PHP
PHP实现链式操作的三种方法详解
2017/11/16 PHP
JS基于Ajax实现的网页Loading效果代码
2015/10/27 Javascript
Bootstrap模态框调用功能实现方法
2016/09/19 Javascript
Vue.Draggable实现拖拽效果
2020/07/29 Javascript
详解vue-router和vue-cli以及组件之间的传值
2017/07/04 Javascript
Node.js学习教程之HTTP/2服务器推送【译】
2017/10/31 Javascript
vue项目中用cdn优化的方法
2018/01/03 Javascript
JS实现为动态创建的元素添加事件操作示例
2018/03/17 Javascript
Vue项目中添加锁屏功能实现思路
2018/06/29 Javascript
vue input实现点击按钮文字增删功能示例
2019/01/29 Javascript
微信小程序的tab选项卡的实现效果
2019/05/15 Javascript
Vue 实现简易多行滚动"弹幕"效果
2020/01/02 Javascript
[36:45]TNC vs VGJ.S 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
[56:17]NB vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第三场 8.22
2019/09/05 DOTA
利用Python爬取可用的代理IP
2016/08/18 Python
Python根据当前日期取去年同星期日期
2019/04/14 Python
python3的print()函数的用法图文讲解
2019/07/16 Python
django重新生成数据库中的某张表方法
2019/08/28 Python
用python监控服务器的cpu,磁盘空间,内存,超过邮件报警
2021/01/29 Python
详解使用postMessage解决iframe跨域通信问题
2019/11/01 HTML / CSS
SheIn俄罗斯:时尚女装网上商店
2017/02/28 全球购物
请描述一下”is a”关系和”has a”关系
2015/02/03 面试题
优秀研究生自我鉴定
2013/12/04 职场文书
会计辞职信范文
2014/01/15 职场文书
初一地理教学反思
2014/01/16 职场文书
海洋科学专业求职信
2014/08/10 职场文书
农业局党的群众路线教育实践活动整改方案
2014/09/20 职场文书
2015年清明节演讲稿范文
2015/03/17 职场文书
民间借贷借条范本
2015/05/25 职场文书
保护动物的宣传语
2015/07/13 职场文书
庭外和解协议书
2016/03/23 职场文书
2019年让高校“心动”的自荐信
2019/03/25 职场文书
Python趣味挑战之给幼儿园弟弟生成1000道算术题
2021/05/28 Python