keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python3基础之list列表实例解析
Aug 13 Python
wxPython定时器wx.Timer简单应用实例
Jun 03 Python
Python编程中字符串和列表的基本知识讲解
Oct 14 Python
Python标准库sched模块使用指南
Jul 06 Python
Python 模拟登陆的两种实现方法
Aug 10 Python
Tornado高并发处理方法实例代码
Jan 15 Python
python实现烟花小程序
Jan 30 Python
Python OpenCV中的resize()函数的使用
Jun 20 Python
pycharm双击无响应(打不开问题解决办法)
Jan 10 Python
python字符串常用方法及文件简单读写的操作方法
Mar 04 Python
Python RabbitMQ实现简单的进程间通信示例
Jul 02 Python
python 实现的车牌识别项目
Jan 25 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
php实现读取超大文件的方法
2014/07/28 PHP
PHP简单预防sql注入的方法
2016/09/27 PHP
Laravel中任务调度console使用方法小结
2017/05/07 PHP
一个加载js文件的小脚本
2007/06/28 Javascript
dtree 网页树状菜单及传递对象集合到js内,动态生成节点
2012/04/14 Javascript
可恶的ie8提示缺少id未定义
2014/03/20 Javascript
JavaScript中双叹号(!!)作用示例介绍
2014/04/10 Javascript
JavaScript中getUTCSeconds()方法的使用详解
2015/06/11 Javascript
利用transition实现文字上下抖动的效果
2017/01/21 Javascript
JavaScript中从setTimeout与setInterval到AJAX异步
2017/02/13 Javascript
利用babel将es6语法转es5的简单示例
2017/12/01 Javascript
vue中v-for加载本地静态图片方法
2018/03/03 Javascript
js全屏事件fullscreenchange 实现全屏、退出全屏操作
2019/09/17 Javascript
JavaScript中CreateTextFile函数
2020/08/30 Javascript
vue组件讲解(is属性的用法)模板标签替换操作
2020/09/04 Javascript
微信小程序使用前置摄像头拍照
2020/10/22 Javascript
[54:10]Spirit vs NB Supermajor小组赛 A组败者组决赛 BO3 第一场 6.2
2018/06/03 DOTA
Python计算斗牛游戏概率算法实例分析
2017/09/26 Python
python3中的md5加密实例
2018/05/29 Python
python如何爬取个性签名
2018/06/19 Python
详解python中__name__的意义以及作用
2019/08/07 Python
Python流程控制 if else实现解析
2019/09/02 Python
Python3 使用selenium插件爬取苏宁商家联系电话
2019/12/23 Python
python3实现从kafka获取数据,并解析为json格式,写入到mysql中
2019/12/23 Python
Python 读取WAV音频文件 画频谱的实例
2020/03/14 Python
python文件及目录操作代码汇总
2020/07/08 Python
10个最常见的HTML5面试题 附答案
2016/06/06 HTML / CSS
Spartoo芬兰:欧洲最大的网上鞋店
2016/08/28 全球购物
美国美食礼品篮网站:Gourmet Gift Baskets
2019/12/15 全球购物
实习老师离校感言
2014/02/03 职场文书
遥感技术与仪器求职信
2014/02/22 职场文书
《记金华的双龙洞》教学反思
2014/04/19 职场文书
2014普法依法治理工作总结
2014/12/18 职场文书
机动车交通事故协议书
2015/01/29 职场文书
小学教师工作总结2015
2015/04/07 职场文书
Nginx缓存设置案例详解
2021/09/15 Servers