keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析xml模块封装代码
Feb 07 Python
python定时检查某个进程是否已经关闭的方法
May 20 Python
python 输出上个月的月末日期实例
Apr 11 Python
Python使用装饰器模拟用户登陆验证功能示例
Aug 24 Python
Python操作mongodb数据库的方法详解
Dec 08 Python
Python实现的线性回归算法示例【附csv文件下载】
Dec 29 Python
使用python opencv对目录下图片进行去重的方法
Jan 12 Python
python实现浪漫的烟花秀
Jan 30 Python
python Tcp协议发送和接收信息的例子
Jul 22 Python
python scipy卷积运算的实现方法
Sep 16 Python
PyCharm+PyQt5+QtDesigner配置详解
Aug 12 Python
Python中for后接else的语法使用
May 18 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
php中iconv函数使用方法
2008/05/24 PHP
ThinkPHP3.1基础知识快速入门
2014/06/19 PHP
yiic命令时提示“php.exe”不是内部或外部命令的解决方法
2014/12/18 PHP
Java中final关键字详解
2015/08/10 PHP
Zend Framework教程之Bootstrap类用法概述
2016/03/14 PHP
详解PHP中cookie和session的区别及cookie和session用法小结
2016/06/12 PHP
Yii2实现中国省市区三级联动实例
2017/02/08 PHP
php实现的生成迷宫与迷宫寻址算法完整实例
2017/11/06 PHP
同一个表单 根据要求递交到不同页面的实现方法小结
2009/08/05 Javascript
jQuery功能函数详解
2015/02/01 Javascript
微信浏览器内置JavaScript对象WeixinJSBridge使用实例
2015/05/25 Javascript
深入探究使JavaScript动画流畅的一些方法
2015/06/30 Javascript
Javascript验证方法大全
2015/09/21 Javascript
js中el表达式的使用和非空判断方法
2018/03/28 Javascript
vue项目引入字体.ttf的方法
2018/09/28 Javascript
解决angularjs service中依赖注入$scope报错的问题
2018/10/02 Javascript
使用nvm和nrm优化node.js工作流的方法
2019/01/17 Javascript
vue-socket.io接收不到数据问题的解决方法
2020/05/13 Javascript
vue使用openlayers实现移动点动画
2020/09/24 Javascript
python实现JAVA源代码从ANSI到UTF-8的批量转换方法
2015/08/10 Python
win系统下为Python3.5安装flask-mongoengine 库
2016/12/20 Python
批量获取及验证HTTP代理的Python脚本
2017/04/23 Python
使用Python写一个小游戏
2018/04/02 Python
使用pandas批量处理矢量化字符串的实例讲解
2018/07/10 Python
详解python编译器和解释器的区别
2019/06/24 Python
Python 支持向量机分类器的实现
2020/01/15 Python
CSS3弹性伸缩布局之box布局
2016/07/12 HTML / CSS
美国蔬菜和植物种子公司:Burpee
2017/02/01 全球购物
介绍一下Ruby的多线程处理
2013/02/01 面试题
《厄运打不垮的信念》教学反思
2014/04/13 职场文书
英语教育专业自荐信
2014/05/29 职场文书
我的中国梦演讲稿高中篇
2014/08/19 职场文书
营业员岗位职责范本
2015/04/14 职场文书
golang特有程序结构入门教程
2021/06/02 Python
CSS使用Flex和Grid布局实现3D骰子
2022/08/05 HTML / CSS
Java使用HttpClient实现文件下载
2022/08/14 Java/Android