keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 文件重命名工具代码
Jul 26 Python
详解K-means算法在Python中的实现
Dec 05 Python
Python使用三种方法实现PCA算法
Dec 12 Python
特征脸(Eigenface)理论基础之PCA主成分分析法
Mar 13 Python
linux下python使用sendmail发送邮件
May 22 Python
Python实现监控Nginx配置文件的不同并发送邮件报警功能示例
Feb 26 Python
python中的数据结构比较
May 13 Python
如何使用pyinstaller打包32位的exe程序
May 26 Python
python多线程http压力测试脚本
Jun 25 Python
Python实现代码统计工具
Sep 19 Python
python实现批量提取指定文件夹下同类型文件
Apr 05 Python
彻底弄懂Python中的回调函数(callback)
Jun 25 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
PHP 采集心得技巧
2009/05/15 PHP
PHP中uploaded_files函数使用方法详解
2011/03/09 PHP
Windows下部署Apache+PHP+MySQL运行环境实战
2012/08/31 PHP
php笔记之:有规律大文件的读取与写入的分析
2013/04/26 PHP
php指定函数参数默认值示例代码
2013/12/04 PHP
php eval函数一句话木马代码
2015/05/21 PHP
php通过排列组合实现1到9数字相加都等于20的方法
2015/08/03 PHP
PHP实现文件上传和多文件上传
2015/12/24 PHP
php获取一定范围内取N个不重复的随机数
2016/05/28 PHP
php mysql procedure实现获取多个结果集的方法【基于thinkPHP】
2016/11/09 PHP
PHP大文件分割上传 PHP分片上传
2017/08/28 PHP
Aster vs KG BO3 第三场2.19
2021/03/10 DOTA
Javascript 函数中的参数使用分析
2010/03/27 Javascript
jquery键盘事件介绍
2011/01/31 Javascript
30个让人兴奋的视差滚动(Parallax Scrolling)效果网站
2012/03/04 Javascript
通过location.replace禁止浏览器后退防止重复提交
2014/09/04 Javascript
Javascript基础教程之数据类型 (布尔型 Boolean)
2015/01/18 Javascript
js实现文本框宽度自适应文本宽度的方法
2015/08/13 Javascript
jQuery实现菜单栏导航效果
2017/08/15 jQuery
JS实现的找零张数最小问题示例
2017/11/28 Javascript
[01:38]DOTA2 2015国际邀请赛中国区预选赛 Showopen
2015/06/01 DOTA
python matplotlib 注释文本箭头简单代码示例
2018/01/08 Python
深入理解Python异常处理的哲学
2019/02/01 Python
在vscode中配置python环境过程解析
2019/09/28 Python
如何基于Python创建目录文件夹
2019/12/31 Python
Django Channel实时推送与聊天的示例代码
2020/04/30 Python
python中wheel的用法整理
2020/06/15 Python
Linux开机引导的步骤是什么
2014/02/26 面试题
产品推广策划方案
2014/05/10 职场文书
工程承包协议书
2014/10/20 职场文书
2014年销售助理工作总结
2014/12/01 职场文书
教师个人事迹材料
2014/12/17 职场文书
2015年护士节慰问信
2015/03/23 职场文书
2015年社区关工委工作总结
2015/04/03 职场文书
Python Pandas pandas.read_sql函数实例用法
2021/06/21 Python
JavaScript中reduce()的用法
2022/05/11 Javascript