浅谈keras使用预训练模型vgg16分类,损失和准确度不变


Posted in Python onJuly 02, 2020

问题keras使用预训练模型vgg16分类,损失和准确度不变。

细节:使用keras训练一个两类数据,正负比例1:3,在vgg16后添加了几个全链接并初始化了。并且对所有层都允许训练。

但是准确度一直是0.75.

数据预先处理已经检查过格式正确

再将模型中relu改成sigmoid就正常了。

数据处理程序

import os
import pickle
import numpy as np
 
import DataFile
import SelectiveSearch
import Generator
import IoU
import Model_CRNN_VGG16
 
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
 
def data_generator(gen1,gen0):
 while True:
 data_pos = next(gen1)
 data_neg = next(gen0)
 ret_X = np.vstack((data_pos[0],data_neg[0]))
 ret_y = np.vstack((data_pos[1],data_neg[1]))
 
 index = np.arange(ret_y.shape[0])
 np.random.shuffle(index)
 
 ret_X = ret_X[index, :, :, :] # X_train是训练集,y_train是训练标签
 ret_y = ret_y[index]
 yield ret_X,ret_y
 
if __name__ == "__main__":
 type = "train"
 
 # 数据生成器,每个mini-batch包含32个正样本(属于VOC 20个类别),96个负样本(background)
 RESIZE = (224, 224)
 path = "category_images"
 categories = os.listdir(path)
 categories.append('background')
 print(categories)
 
 train_1_datagen = ImageDataGenerator(
 rescale=1.0/255,
 #shear_range=0.2,
 #zoom_range=0.2,
 horizontal_flip=True)
 
 train_1_generator = train_1_datagen.flow_from_directory(
 'category_images',
 target_size=RESIZE,
 batch_size=32,
 classes = categories)
 
 train_0_datagen = ImageDataGenerator(
 rescale=1.0 / 255,
 #shear_range=0.2,
 #zoom_range=0.2,
 horizontal_flip=True)
 
 train_0_generator = train_0_datagen.flow_from_directory(
 'category_background',
 target_size=RESIZE,
 batch_size=32*3,
 classes=categories)
 
 generator = data_generator(train_1_generator,train_0_generator)
 
 # 创建模型
 model = Model_CRNN_VGG16.CRNN_Model(input_shape=(*RESIZE,3))
 cnn = model.CNN(len(categories))
 if os.path.exists('weights-cnn.hdf5'):
 cnn.load_weights('weights-cnn.hdf5')
 if type == "train":
 checkpoint = ModelCheckpoint('weights-cnn.hdf5',save_weights_only=True)
 cnn.fit_generator(generator = generator,steps_per_epoch=200,epochs=1000,callbacks=[checkpoint])
 else:
 img = next(generator)[0]
 result = cnn.predict(img)
 print(result)
 
 # 训练SVM
 # 非极大值抑制
 # 预测

模型程序:

from keras.applications.vgg16 import VGG16
from keras.layers import *
from keras.models import Model
from keras.optimizers import SGD,Adam
 
class CRNN_Model():
 def __init__(self,input_shape,trainable=True):
 vgg16 = VGG16(include_top=False,weights="imagenet", input_shape=input_shape)
 for layer in vgg16.layers:
  layer.trainable = trainable
 self.base_model = vgg16
 
 def CNN(self,classes):
 img_input = self.base_model.input
 x = self.base_model.get_layer('block5_conv3').output
 
 x = Flatten(name='crnn_flatten')(x)
 
 x = Dense(512,activation='relu', kernel_initializer='he_normal', name='crnn_fc1')(x)
 x = Dense(512,activation='relu', kernel_initializer='he_normal',name='crnn_fc2')(x)
 x = Dense(classes, activation='softmax', kernel_initializer='he_normal', name='crnn_predictions')(x)
 
 model = Model(img_input,x)
 
 sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
 adam = Adam()
 model.compile(optimizer=adam,
   loss='categorical_crossentropy',
   metrics=['accuracy'])
 
 model.summary()
 return model
if __name__ == "__main__":
 pass

补充知识:val_acc一直不变

val_loss一直不变的原因

之前用keras编写了LSTM模型,做图片分类,自己划分了测试集和训练集,但是得到的结果是每个epoch训练的准确率都不变。

浅谈keras使用预训练模型vgg16分类,损失和准确度不变

探索

我一直以为是我的数据的读取方式不对,我一直在从这方面下手,但是后来我发现根本不是这个原因,也找到了解决方案,具体原因有三点,三点是递进关系。

1.数据集样本各类别数量差距大

如果没有这种情况就看看第二点。

2.训练集和数据集是手动划分的,改为代码自动划分

代码如下:

X_train, X_test,Y_train, Y_test = train_test_split(data, labels, test_size=0.4, random_state=42)```

上述方法要多设置几个epoch,要有耐心的等,如果还是测试的准确率还是不变,那就可能是第二个原因。

3. 训练模型不适用,或者模型参数不恰当,建议调参,或者改算法

如果第一个方法还是不行那就可能是算法不适合这个数据集,可以打印混淆矩阵看一下,是不是分类错误率太高,比如我的数据集,做二分类,结果第二类全分到第一类了。

以上这篇浅谈keras使用预训练模型vgg16分类,损失和准确度不变就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现自动登录人人网并访问最近来访者实例
Sep 26 Python
Django的数据模型访问多对多键值的方法
Jul 21 Python
关于pip的安装,更新,卸载模块以及使用方法(详解)
May 19 Python
PyQt5实现拖放功能
Apr 25 Python
Python采集代理ip并判断是否可用和定时更新的方法
May 07 Python
pip命令无法使用的解决方法
Jun 12 Python
python和shell监控linux服务器的详细代码
Jun 22 Python
python使用tornado实现登录和登出
Jul 28 Python
Python实现正整数分解质因数操作示例
Aug 01 Python
wxPython之wx.DC绘制形状
Nov 19 Python
Python实现平行坐标图的绘制(plotly)方式
Nov 22 Python
python 利用matplotlib在3D空间绘制二次抛物面的案例
Feb 06 Python
python脚本和网页有何区别
Jul 02 #Python
keras:model.compile损失函数的用法
Jul 01 #Python
win10安装python3.6的常见问题
Jul 01 #Python
Python代码需要缩进吗
Jul 01 #Python
导致python中import错误的原因是什么
Jul 01 #Python
详细分析Python垃圾回收机制
Jul 01 #Python
Python自带的IDE在哪里
Jul 01 #Python
You might like
汉字转化为拼音(php版)
2006/10/09 PHP
PHP执行Curl时报错提示CURL ERROR: Recv failure: Connection reset by peer的解决方法
2014/06/26 PHP
分享50个提高PHP执行效率的技巧
2015/12/26 PHP
[原创]PHP字符串中插入子字符串方法总结
2016/05/06 PHP
Apache无法自动跳转却显示目录的解决方法
2020/11/30 PHP
对textarea框的代码调试,而且功能上使用非常方便,酷
2006/06/30 Javascript
ExtJS 2.2.1的grid控件在ie6中的显示问题
2009/05/04 Javascript
Mootools 1.2教程 正则表达式
2009/09/15 Javascript
json格式化/压缩工具 Chrome插件扩展版
2010/05/25 Javascript
基于JQuery的一句代码实现表格的简单筛选
2010/07/26 Javascript
javascript 基础篇3 类,回调函数,内置对象,事件处理
2012/03/14 Javascript
JS运动框架之分享侧边栏动画实例
2015/03/03 Javascript
详解JavaScript节流函数中的Throttle
2016/07/16 Javascript
在html中引入外部js文件,并调用带参函数的方法
2016/10/31 Javascript
Angular的$http的ajax的请求操作(推荐)
2017/01/10 Javascript
js cookie实现记住密码功能
2017/01/17 Javascript
Vue.js 插件开发详解
2017/03/29 Javascript
微信小程序项目实践之验证码倒计时功能
2018/07/18 Javascript
JavaScript cookie原理及使用实例
2020/05/08 Javascript
vue修改Element的el-table样式的4种方法
2020/09/17 Javascript
Python的垃圾回收机制深入分析
2014/07/16 Python
浅谈python对象数据的读写权限
2016/09/12 Python
Python cookbook(数据结构与算法)在字典中将键映射到多个值上的方法
2018/02/18 Python
python和pygame实现简单俄罗斯方块游戏
2021/02/19 Python
Python设计模式之解释器模式原理与用法实例分析
2019/01/10 Python
pandas数据拼接的实现示例
2020/04/16 Python
Python如何实现后端自定义认证并实现多条件登陆
2020/06/22 Python
四风存在的原因分析
2014/02/11 职场文书
优秀学生评语大全
2014/04/25 职场文书
四风个人对照检查材料思想汇报(办公室通用版)
2014/10/07 职场文书
三人合伙协议书范本
2014/10/29 职场文书
2015年政治教研组工作总结
2015/07/22 职场文书
新西兰:最新留学学习计划书写作指南
2019/07/15 职场文书
Mysql 如何批量插入数据
2021/04/06 MySQL
DjangoRestFramework 使用 simpleJWT 登陆认证完整记录
2021/06/22 Python
关于的python五子棋的算法
2022/05/02 Python