keras实现VGG16方式(预测一张图片)


Posted in Python onJuly 07, 2020

我就废话不多说了,大家还是直接看代码吧~

from keras.applications.vgg16 import VGG16#直接导入已经训练好的VGG16网络
from keras.preprocessing.image import load_img#load_image作用是载入图片
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions
 
model = VGG16()
image = load_img('D:\\photo\\dog.jpg',target_size=(224,224))#参数target_size用于设置目标的大小,如此一来无论载入的原图像大小如何,都会被标准化成统一的大小,这样做是为了向神经网络中方便地输入数据所需的。
image = img_to_array(image)#函数img_to_array会把图像中的像素数据转化成NumPy中的array,这样数据才可以被Keras所使用。
#神经网络接收一张或多张图像作为输入,也就是说,输入的array需要有4个维度: samples, rows, columns, and channels。由于我们仅有一个 sample(即一张image),我们需要对这个array进行reshape操作。
image = image.reshape((1,image.shape[0],image.shape[1],image.shape[2]))
image = preprocess_input(image)#对图像进行预处理
y = model.predict(image)#预测图像的类别
label = decode_predictions(y)#Keras提供了一个函数decode_predictions(),用以对已经得到的预测向量进行解读。该函数返回一个类别列表,以及类别中每个类别的预测概率,
label = label[0][0]
print('%s(%.2f%%)'%(label[1],label[2]*100))
# print(model.summary())
from keras.models import Sequential
from keras.layers.core import Flatten,Dense,Dropout
from keras.layers.convolutional import Convolution2D,MaxPooling2D,ZeroPadding2D
from keras.optimizers import SGD
import numpy as np
 
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
import time
from keras import backend as K
K.set_image_dim_ordering('th')
def VGG_16(weights_path=None):
  model = Sequential()
 
  model.add(ZeroPadding2D((1, 1), input_shape=(3, 224, 224)))
  model.add(Convolution2D(64, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(64, (3, 3), activation='relu'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))
 
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(128, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(128, (3, 3), activation='relu'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))
 
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(256, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(256, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(256, (3, 3), activation='relu'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))
 
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, (3, 3), activation='relu'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))
 
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, (3, 3), activation='relu'))
  model.add(ZeroPadding2D((1, 1)))
  model.add(Convolution2D(512, (3, 3), activation='relu'))
  model.add(MaxPooling2D((2, 2), strides=(2, 2)))
 
  model.add(Flatten())
  model.add(Dense(4096, activation='relu'))
  model.add(Dropout(0.5))
  model.add(Dense(4096, activation='relu'))
  model.add(Dropout(0.5))
  model.add(Dense(1000, activation='softmax'))
 
  if weights_path:
    model.load_weights(weights_path,by_name=True)
 
  return model
 
model = VGG_16(weights_path='F:\\Kaggle\\vgg16_weights.h5')
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy')
 
t0 = time.time()
img = image.load_img('D:\\photo\\dog.jpg', target_size=(224, 224))
x = image.img_to_array(img) # 三维(224,224,3)
x = np.expand_dims(x, axis=0) # 四维(1,224,224,3)#因为keras要求的维度是这样的,所以要增加一个维度
x = preprocess_input(x) # 预处理
print(x.shape)
y_pred = model.predict(x) # 预测概率
 
t1 = time.time()
 
print("测试图:", decode_predictions(y_pred)) # 输出五个最高概率(类名, 语义概念, 预测概率)
print("耗时:", str((t1 - t0) * 1000), "ms")

这是两种不同的方式,第一种是直接使用vgg16的参数,需要在运行时下载,第二种是我们已经下载好的权重,直接在参数中输入我们的路径即可。

补充知识:keras加经典网络的预训练模型(以VGG16为例)

我就废话不多说了,大家还是直接看代码吧~

# 使用VGG16模型
from keras.applications.vgg16 import VGG16
print('Start build VGG16 -------')
 
# 获取vgg16的卷积部分,如果要获取整个vgg16网络需要设置:include_top=True
model_vgg16_conv = VGG16(weights='imagenet', include_top=False)
model_vgg16_conv.summary()
 
# 创建自己的输入格式
# if K.image_data_format() == 'channels_first':
#  input_shape = (3, img_width, img_height)
# else:
#  input_shape = (img_width, img_height, 3)
 
input = Input(input_shape, name = 'image_input') # 注意,Keras有个层就是Input层
 
# 将vgg16模型原始输入转换成自己的输入
output_vgg16_conv = model_vgg16_conv(input)
 
# output_vgg16_conv是包含了vgg16的卷积层,下面我需要做二分类任务,所以需要添加自己的全连接层
x = Flatten(name='flatten')(output_vgg16_conv)
x = Dense(4096, activation='relu', name='fc1')(x)
x = Dense(512, activation='relu', name='fc2')(x)
x = Dense(128, activation='relu', name='fc3')(x)
x = Dense(1, activation='softmax', name='predictions')(x)
 
# 最终创建出自己的vgg16模型
my_model = Model(input=input, output=x)
 
# 下面的模型输出中,vgg16的层和参数不会显示出,但是这些参数在训练的时候会更改
print('\nThis is my vgg16 model for the task')
my_model.summary()

以上这篇keras实现VGG16方式(预测一张图片)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python过滤字符串中不属于指定集合中字符的类实例
Jun 30 Python
Python django实现简单的邮件系统发送邮件功能
Jul 14 Python
Python3中的列表生成式、生成器与迭代器实例详解
Jun 11 Python
python3 对list中每个元素进行处理的方法
Jun 29 Python
python多行字符串拼接使用小括号的方法
Mar 19 Python
在python中安装basemap的教程
Sep 20 Python
Python实现分段线性插值
Dec 17 Python
Django框架模板注入操作示例【变量传递到模板】
Dec 19 Python
通过PYTHON来实现图像分割详解
Jun 26 Python
python绘制双Y轴折线图以及单Y轴双变量柱状图的实例
Jul 08 Python
Django用户认证系统 User对象解析
Aug 02 Python
Python语言内置数据类型
Feb 24 Python
通过实例解析Python RPC实现原理及方法
Jul 07 #Python
Keras预训练的ImageNet模型实现分类操作
Jul 07 #Python
Scrapy模拟登录赶集网的实现代码
Jul 07 #Python
scrapy框架携带cookie访问淘宝购物车功能的实现代码
Jul 07 #Python
Keras构建神经网络踩坑(解决model.predict预测值全为0.0的问题)
Jul 07 #Python
浅谈django框架集成swagger以及自定义参数问题
Jul 07 #Python
Django REST Swagger实现指定api参数
Jul 07 #Python
You might like
PHP 字符串加密函数(在指定时间内加密还原字符串,超时无法还原)
2010/04/28 PHP
PHP自动重命名文件实现方法
2014/11/04 PHP
jquery插件 autoComboBox 下拉框
2010/12/22 Javascript
前台js改变Session的值(用ajax实现)
2012/12/28 Javascript
js 火狐下取本地路径实现思路
2013/04/02 Javascript
AngularJS的内置过滤器详解
2015/05/14 Javascript
详解AngularJS中module模块的导入导出
2015/12/10 Javascript
详解Bootstrap的iCheck插件checkbox和radio
2016/08/24 Javascript
JavaScript lodash常见用法系列小结
2016/08/24 Javascript
详解Vue.js组件可复用性的混合(mixin)方式和自定义指令
2017/09/06 Javascript
Javascript刷新页面的实例
2017/09/23 Javascript
vue2.x集成百度UEditor富文本编辑器的方法
2018/09/21 Javascript
python ElementTree 基本读操作示例
2009/04/09 Python
Python专用方法与迭代机制实例分析
2014/09/15 Python
用Python编写简单的定时器的方法
2015/05/02 Python
python通过ftplib登录到ftp服务器的方法
2015/05/08 Python
python互斥锁、加锁、同步机制、异步通信知识总结
2018/02/11 Python
pandas.dataframe按行索引表达式选取方法
2018/10/30 Python
对python PLT中的image和skimage处理图片方法详解
2019/01/10 Python
使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
2020/01/18 Python
Skip Hop官网:好莱坞宝宝挚爱品牌
2018/06/17 全球购物
美国亚洲时尚和美容产品的一站式网上商店:Stylevana
2019/09/05 全球购物
什么是serialVersionUID
2016/03/04 面试题
金融学专科生自我鉴定
2014/02/21 职场文书
中文教师求职信
2014/02/22 职场文书
推荐信模板
2014/05/09 职场文书
企业文化标语大全
2014/06/10 职场文书
会计人员演讲稿
2014/09/11 职场文书
银行党员批评与自我批评
2014/10/15 职场文书
作风建设整改方案
2014/10/27 职场文书
2014年项目经理工作总结
2014/11/24 职场文书
安全守法证明
2015/06/23 职场文书
初二物理教学反思
2016/02/19 职场文书
英语版自我评价,35句话轻松搞定
2019/10/08 职场文书
Nginx同一个域名配置多个项目的实现方法
2021/03/31 Servers
详解Python 3.10 中的新功能和变化
2021/04/28 Python