keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现远程调用MetaSploit的方法
Aug 22 Python
Python实现将数据框数据写入mongodb及mysql数据库的方法
Apr 02 Python
Python中存取文件的4种不同操作
Jul 02 Python
Python中浅拷贝copy与深拷贝deepcopy的简单理解
Oct 26 Python
Django之提交表单与前后端交互的方法
Jul 19 Python
pytorch构建多模型实例
Jan 15 Python
python GUI库图形界面开发之PyQt5访问系统剪切板QClipboard类详细使用方法与实例
Feb 27 Python
Python接口测试文件上传实例解析
May 22 Python
你需要学会的8个Python列表技巧
Jun 24 Python
pycharm导入源码的具体步骤
Aug 04 Python
详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程
Nov 02 Python
python异常中else的实例用法
Jun 15 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
php Mysql日期和时间函数集合
2007/11/16 PHP
程序员编程十条戒律
2009/07/09 PHP
php生成静态html页面的方法(2种方法)
2015/09/14 PHP
利用PHP命令行模式采集股票趋势信息
2016/08/09 PHP
Jquery Ajax学习实例3 向WebService发出请求,调用方法返回数据
2010/03/16 Javascript
jQuery 过滤not()与filter()实例代码
2012/05/10 Javascript
ajax上传时参数提交不更新等相关问题
2012/12/11 Javascript
datagrid框架的删除添加与修改
2013/04/08 Javascript
JS加载器如何动态加载外部js文件
2016/05/26 Javascript
Bootstrap实现带动画过渡的弹出框
2016/08/09 Javascript
轻松掌握JavaScript单例模式
2016/08/25 Javascript
JavaScript的new date等日期函数在safari中遇到的坑
2016/10/24 Javascript
处理JavaScript值为undefined的7个小技巧
2020/07/28 Javascript
vue 接口请求地址前缀本地开发和线上开发设置方式
2020/08/13 Javascript
[58:29]DOTA2-DPC中国联赛 正赛 Phoenix vs XG BO3 第一场 1月31日
2021/03/11 DOTA
python 简易计算器程序,代码就几行
2009/08/29 Python
python解析xml文件实例分析
2015/05/27 Python
Python实现的栈、队列、文件目录遍历操作示例
2019/05/06 Python
基于Numpy.convolve使用Python实现滑动平均滤波的思路详解
2019/05/16 Python
PyCharm专业最新版2019.1安装步骤(含激活码)
2019/10/09 Python
python 解决flask 图片在线浏览或者直接下载的问题
2020/01/09 Python
Python类class参数self原理解析
2020/11/19 Python
PyCharm 光标变成黑块的解决方式
2021/02/06 Python
如何用Matlab和Python读取Netcdf文件
2021/02/19 Python
浅谈基于Canvas的手绘风格图形库Rough.js
2018/03/19 HTML / CSS
专业销售业务员求职信
2013/11/18 职场文书
三年级学生评语
2014/04/23 职场文书
党在我心中演讲稿
2014/09/02 职场文书
公司证明怎么写
2014/09/22 职场文书
台风停课通知
2015/04/24 职场文书
2015年乡镇财政工作总结
2015/05/19 职场文书
公务员岗前培训心得体会
2016/01/08 职场文书
MySQL之DML语言
2021/04/05 MySQL
SQL Server中常用截取字符串函数介绍
2022/03/16 SQL Server
Python turtle编写简单的球类小游戏
2022/03/31 Python
大型强子对撞机再次重启探索“第五种自然力”
2022/04/29 数码科技