keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用chardet判断字符编码
May 09 Python
Python AES加密模块用法分析
May 22 Python
Python中使用支持向量机SVM实践
Dec 27 Python
tensorflow使用神经网络实现mnist分类
Sep 08 Python
python utc datetime转换为时间戳的方法
Jan 15 Python
Python从文件中读取数据的方法讲解
Feb 14 Python
关于django 1.10 CSRF验证失败的解决方法
Aug 31 Python
使用Fabric自动化部署Django项目的实现
Sep 27 Python
pytorch常见的Tensor类型详解
Jan 15 Python
python让函数不返回结果的方法
Jun 22 Python
python 实现汉诺塔游戏
Nov 28 Python
python playwrigh框架入门安装使用
Jul 23 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
PHP 5.3 下载时 VC9、VC6、Thread Safe、Non Thread Safe的区别分析
2011/03/28 PHP
php empty() 检查一个变量是否为空
2011/11/10 PHP
如何使用PHP批量去除文件UTF8 BOM信息
2013/08/05 PHP
php文件上传简单实现方法
2015/01/24 PHP
Laravel手动分页实现方法详解
2016/10/09 PHP
PHP使用PHPExcel实现批量上传到数据库的方法
2017/06/08 PHP
javascript据option的value值快速设定初始的selected选项
2007/08/13 Javascript
Jquery倒计时源码分享
2014/05/16 Javascript
使用jQuery在移动页面上添加按钮和给按钮添加图标
2015/12/04 Javascript
JavaScript程序中实现继承特性的方式总结
2016/06/24 Javascript
JS使用正则表达式过滤多个词语并替换为相同长度星号的方法
2016/08/03 Javascript
javascript实现用户点击数量统计
2016/12/25 Javascript
使用layui 渲染table数据表格的实例代码
2018/08/19 Javascript
解决element ui select下拉框不回显数据问题的解决
2019/02/20 Javascript
微信小程序模板消息推送的两种实现方式
2019/08/27 Javascript
JavaScript实现手机号码 3-4-4格式并控制新增和删除时光标的位置
2020/06/02 Javascript
vue实现一个6个输入框的验证码输入组件功能的实例代码
2020/06/29 Javascript
javascript实现图片轮换动作方法
2020/08/07 Javascript
对Python的Django框架中的项目进行单元测试的方法
2016/04/11 Python
使用Python3 编写简单信用卡管理程序
2016/12/21 Python
tensorflow-gpu安装的常见问题及解决方案
2020/01/20 Python
python实现测试工具(二)——简单的ui测试工具
2020/10/19 Python
html5视频播放_动力节点Java学院整理
2017/07/13 HTML / CSS
萨克斯第五大道的折扣店:Saks Fifth Avenue OFF 5TH
2016/08/25 全球购物
夏威夷航空官网:Hawaiian Airlines
2016/09/11 全球购物
加拿大时装零售商:Influence U
2018/12/22 全球购物
李维斯法国官网:Levi’s法国
2019/07/13 全球购物
什么造成了Java里面的异常
2016/04/24 面试题
我爱祖国演讲稿
2014/09/02 职场文书
付款委托书范本
2014/10/05 职场文书
健康状况证明模板
2014/10/23 职场文书
导游词之西安骊山
2019/12/20 职场文书
SQLServer 日期函数大全(小结)
2021/04/08 SQL Server
浅谈Python项目的服务器部署
2021/04/25 Python
利用Java设置Word文本框中的文字旋转方向的实现方法
2021/06/28 Java/Android
深入理解 Golang 的字符串
2022/05/04 Golang