keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的hypot()方法使用简介
May 18 Python
python开发简易版在线音乐播放器
Mar 03 Python
python实现人脸识别经典算法(一) 特征脸法
Mar 13 Python
python自动化生成IOS的图标
Nov 13 Python
python版飞机大战代码分享
Nov 20 Python
python 获取url中的参数列表实例
Dec 18 Python
python 实现多维数组转向量
Nov 30 Python
Windows系统下pycharm中的pip换源
Feb 23 Python
Python实现敏感词过滤的4种方法
Sep 12 Python
python实现移动木板小游戏
Oct 09 Python
Pytorch可视化的几种实现方法
Jun 10 Python
Python+pyaudio实现音频控制示例详解
Jul 23 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
isset和empty的区别
2007/01/15 PHP
php实现图片上传时添加文字和图片水印技巧
2020/04/18 PHP
thinkPHP统计排行与分页显示功能示例
2016/12/02 PHP
PHP7原生MySQL数据库操作实现代码
2020/07/03 PHP
JavaScript 中的replace方法说明
2007/04/13 Javascript
jquery链式操作的正确使用方法
2014/01/06 Javascript
javascript中的undefined和not defined区别示例介绍
2014/02/26 Javascript
js实现ArrayList功能附实例代码
2014/10/29 Javascript
JavaScript常用的返回,自动跳转,刷新,关闭语句汇总
2015/01/13 Javascript
每天一篇javascript学习小结(Date对象)
2015/11/13 Javascript
JavaScript获取浏览器信息的方法
2015/11/20 Javascript
jQuery Mobile漏洞会有跨站脚本攻击风险
2017/02/12 Javascript
利用nodejs监控文件变化并使用sftp上传到服务器
2017/02/18 NodeJs
微信小程序实现顶部选项卡(swiper)
2020/06/19 Javascript
从零开始搭建webpack+react开发环境的详细步骤
2018/05/18 Javascript
通过vue-cli3构建一个SSR应用程序的方法
2018/09/13 Javascript
nodejs使用Sequelize框架操作数据库的实现
2020/10/21 NodeJs
[01:16:13]DOTA2-DPC中国联赛 正赛 SAG vs Dragon BO3 第一场 2月22日
2021/03/11 DOTA
[01:30:15]DOTA2-DPC中国联赛 正赛 Ehome vs Aster BO3 第二场 2月2日
2021/03/11 DOTA
Python实现telnet服务器的方法
2015/07/10 Python
Python实现批量将word转html并将html内容发布至网站的方法
2015/07/14 Python
Python随机数random模块使用指南
2016/09/09 Python
Python3导入CSV文件的实例(跟Python2有些许的不同)
2018/06/22 Python
Python实现的爬取百度文库功能示例
2019/02/16 Python
Python内置类型性能分析过程实例
2020/01/29 Python
松本清官方海外旗舰店:日本最大的药妆连锁店
2017/11/21 全球购物
工程安全员岗位职责
2014/03/09 职场文书
公司委托书怎么写
2014/08/02 职场文书
四风问题查摆剖析材料
2014/10/11 职场文书
中学生的1000字检讨书
2014/10/11 职场文书
无子女夫妻离婚协议书(4篇)
2014/10/20 职场文书
逃课检讨书怎么写
2015/01/01 职场文书
小学教师师德师风自我评价
2015/03/04 职场文书
小学数学教学随笔
2015/08/14 职场文书
2016初一新生军训心得体会
2016/01/11 职场文书
2016参观监狱警示教育活动心得体会
2016/01/15 职场文书