解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题


Posted in Python onJune 17, 2020

遇到的问题

解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题

当时自己在使用Alexnet训练图像分类问题时,会出现损失在一个epoch中增加,换做下一个epoch时loss会骤然降低,一开始这个问题没有一点头绪,我数据也打乱了,使用的是tf.train.shuffle_batch

在capacity中设置一个值,比如是1000吧,每次取一千个数据后将这一千个数据打乱,本次使用的数据集就是每个种类1000多,而我加载数据时是一类一类加载的,这就造成了每一批次的开始可以跟前一类数据做打乱处理,但是在中间数据并不能达到充分的shuffle

解决问题

在加载数据集的时候用numpy中的shuffle将数据集充分的打乱后在读入tfrecord中,之后读取的时候使用tf.tain.shuffle_batch和使用tf.train.batch就没有区别了。另外capacity这个数值不益设置过大,会对自己的电脑造成压力。

补充知识:MATLAB中使用AlexNet、VGG、GoogLeNet进行迁移学习

直接贴代码,具体用法见注释:

clc;clear;

net = alexnet; %加载在ImageNet上预训练的网络模型
imageInputSize = [227 227 3];
%加载图像
allImages = imageDatastore('.\data227Alexnet',...
 'IncludeSubfolders',true,...
 'LabelSource','foldernames');
 %划分训练集和验证集
 [training_set,validation_set] = splitEachLabel(allImages,0.7,'randomized');
 %由于原始网络全连接层1000个输出,显然不适用于我们的分类任务,因此在这里替换
layersTransfer = net.Layers(1:end-3);
categories(training_set.Labels)
numClasses = numel(categories(training_set.Labels));
%新的网络
layers = [
 layersTransfer
 fullyConnectedLayer(numClasses,'Name', 'fc','WeightLearnRateFactor',1,'BiasLearnRateFactor',1)
 softmaxLayer('Name', 'softmax')
 classificationLayer('Name', 'classOutput')];

lgraph = layerGraph(layers);
plot(lgraph)
%对数据集进行扩增
augmented_training_set = augmentedImageSource(imageInputSize,training_set);

opts = trainingOptions('adam', ...
 'MiniBatchSize', 32,... % mini batch size, limited by GPU RAM, default 100 on Titan, 500 on P6000
 'InitialLearnRate', 1e-4,... % fixed learning rate
 'LearnRateSchedule','piecewise',...
 'LearnRateDropFactor',0.25,...
 'LearnRateDropPeriod',10,...
 'L2Regularization', 1e-4,... constraint
 'MaxEpochs',20,..
 'ExecutionEnvironment', 'gpu',...
 'ValidationData', validation_set,...
 'ValidationFrequency',80,...
 'ValidationPatience',8,...
 'Plots', 'training-progress')

net = trainNetwork(augmented_training_set, lgraph, opts);

save Alex_Public_32.mat net

[predLabels,predScores] = classify(net, validation_set);
plotconfusion(validation_set.Labels, predLabels)
PerItemAccuracy = mean(predLabels == validation_set.Labels);
title(['overall per image accuracy ',num2str(round(100*PerItemAccuracy)),'%'])

MATLAB中训练神经网络一个非常大的优势就是训练过程中各项指标的可视化,并且最终也会生成一个混淆矩阵显示验证集的结果。

以上这篇解决Alexnet训练模型在每个epoch中准确率和loss都会一升一降问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python简单计算文件夹大小的方法
Jul 14 Python
Python中文件I/O高效操作处理的技巧分享
Feb 04 Python
Python爬虫工程师面试问题总结
Mar 22 Python
Python DataFrame 设置输出不显示index(索引)值的方法
Jun 07 Python
selenium在执行phantomjs的API并获取执行结果的方法
Dec 17 Python
Python函数的参数常见分类与用法实例详解
Mar 30 Python
python3 打印输出字典中特定的某个key的方法示例
Jul 06 Python
python实现扫雷游戏
Mar 03 Python
TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法
Apr 19 Python
python 实现任务管理清单案例
Apr 25 Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 Python
如何解决安装python3.6.1失败
Jul 01 Python
Java如何基于wsimport调用wcf接口
Jun 17 #Python
使用keras内置的模型进行图片预测实例
Jun 17 #Python
Python虚拟环境库virtualenvwrapper安装及使用
Jun 17 #Python
基于TensorFlow的CNN实现Mnist手写数字识别
Jun 17 #Python
Keras 加载已经训练好的模型进行预测操作
Jun 17 #Python
基于Tensorflow的MNIST手写数字识别分类
Jun 17 #Python
Kears 使用:通过回调函数保存最佳准确率下的模型操作
Jun 17 #Python
You might like
PHP 安全检测代码片段(分享)
2013/07/05 PHP
老生常谈PHP 文件写入和读取(必看篇)
2017/05/22 PHP
基于 Swoole 的微信扫码登录功能实现代码
2018/01/15 PHP
PHP两个n位的二进制整数相加问题的解决
2018/08/26 PHP
jquery $.ajax入门应用一
2008/11/19 Javascript
JavaScript设置FieldSet展开与收缩
2009/05/15 Javascript
jquery attr 设定src中含有&(宏)符号问题的解决方法
2011/07/26 Javascript
js无刷新操作table的行和列
2014/03/27 Javascript
对 jQuery 中 data 方法的误解分析
2014/06/18 Javascript
在easyUI开发中,出现jquery.easyui.min.js函数库问题的解决办法
2015/09/11 Javascript
Node.js实现JS文件合并小工具
2016/02/02 Javascript
JS表单数据验证的正则表达式(常用)
2017/02/18 Javascript
layer弹出层框架alert与msg详解
2017/03/14 Javascript
详解Angular Reactive Form 表单验证
2017/07/06 Javascript
解决webpack打包速度慢的解决办法汇总
2017/07/06 Javascript
vue初始化动画加载的实例
2018/09/01 Javascript
vue.js中导出Excel表格的案例分析
2019/06/11 Javascript
使用zrender.js绘制体温单效果
2019/10/31 Javascript
AutoJs实现刷宝短视频的思路详解
2020/05/22 Javascript
ES6函数和数组用法实例分析
2020/05/23 Javascript
Python中类型检查的详细介绍
2017/02/13 Python
深入理解Python分布式爬虫原理
2017/11/23 Python
widows下安装pycurl并利用pycurl请求https地址的方法
2018/10/15 Python
python3的输入方式及多组输入方法
2018/10/17 Python
python抓取搜狗微信公众号文章
2019/04/01 Python
windows下numpy下载与安装图文教程
2019/04/02 Python
python高斯分布概率密度函数的使用详解
2019/07/10 Python
基于Django静态资源部署404的解决方法
2019/07/28 Python
10行Python代码计算汽车数量的实现方法
2019/10/23 Python
Python模拟FTP文件服务器的操作方法
2020/02/18 Python
Python configparser模块应用过程解析
2020/08/14 Python
CSS3 box-shadow属性实例详解
2020/06/19 HTML / CSS
一名女生的自荐信
2013/12/08 职场文书
团支部推优材料
2014/05/21 职场文书
pytorch实现ResNet结构的实例代码
2021/05/17 Python
MySQL之select、distinct、limit的使用
2021/11/11 MySQL