keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
通过实例浅析Python对比C语言的编程思想差异
Aug 30 Python
Python爬取附近餐馆信息代码示例
Dec 09 Python
深入浅析Python中的yield关键字
Jan 24 Python
Numpy中转置transpose、T和swapaxes的实例讲解
Apr 17 Python
Python使用matplotlib模块绘制图像并设置标题与坐标轴等信息示例
May 04 Python
python numpy和list查询其中某个数的个数及定位方法
Jun 27 Python
Python 打印中文字符的三种方法
Aug 14 Python
python学习--使用QQ邮箱发送邮件代码实例
Apr 16 Python
Python中利用LSTM模型进行时间序列预测分析的实现
Jul 26 Python
如何基于python操作excel并获取内容
Dec 24 Python
python和js交互调用的方法
Jun 23 Python
python 标准库原理与用法详解之os.path篇
Oct 24 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
E路文章系统PHP
2006/12/11 PHP
给初学者的30条PHP最佳实践(荒野无灯)
2011/08/02 PHP
php中file_get_content 和curl以及fopen 效率分析
2014/09/19 PHP
超详细的php用户注册页面填写信息完整实例(附源码)
2015/11/17 PHP
隐藏Nginx或Apache以及PHP的版本号的方法
2016/01/03 PHP
浅谈PHP Cookie处理函数
2016/06/10 PHP
Laravel配置全局公共函数的方法步骤
2019/05/09 PHP
Tab页界面,用jQuery及Ajax技术实现
2009/09/21 Javascript
用JQuery实现表格隔行变色和突出显示当前行的代码
2012/02/10 Javascript
js实现可拖动DIV的方法
2013/12/17 Javascript
js实现使用鼠标拖拽切换图片的方法
2015/05/04 Javascript
JavaScript开发Chrome浏览器扩展程序UI的教程
2016/05/16 Javascript
JavaScript实现广告弹窗效果
2016/08/09 Javascript
详解vue跨组件通信的几种方法
2017/06/15 Javascript
微信小程序图片选择区域裁剪实现方法
2017/12/02 Javascript
微信小程序实现给嵌套template模板传递数据的方式总结
2017/12/18 Javascript
JavaScript数组排序reverse()和sort()方法详解
2017/12/24 Javascript
vue富文本编辑器组件vue-quill-edit使用教程
2018/09/21 Javascript
[00:32]2016完美“圣”典风云人物:Maybe宣传片
2016/12/05 DOTA
Python面向对象之静态属性、类方法与静态方法分析
2018/08/24 Python
Python 判断奇数偶数的方法
2018/12/20 Python
Python中pymysql 模块的使用详解
2019/08/12 Python
Django接收自定义http header过程详解
2019/08/23 Python
Python 函数用法简单示例【定义、参数、返回值、函数嵌套】
2019/09/20 Python
python之列表推导式的用法
2019/11/29 Python
Cpython解释器中的GIL全局解释器锁
2020/11/09 Python
检测浏览器对HTML5和CSS3支持度的方法
2015/06/25 HTML / CSS
通过HTML5规范搞定i、em、b、strong元素的区别
2017/03/04 HTML / CSS
LightInTheBox西班牙站点:全球商品在线采购
2016/09/22 全球购物
Booking.com西班牙:全球酒店预订
2018/03/30 全球购物
小学校园活动策划
2014/01/30 职场文书
秋游活动策划方案
2014/02/16 职场文书
酒店员工职业生涯规划
2014/02/25 职场文书
年终考核实施方案
2014/05/26 职场文书
2015年班组长工作总结
2015/04/10 职场文书
小学班主任工作经验交流材料
2015/11/02 职场文书