keras训练曲线,混淆矩阵,CNN层输出可视化实例


Posted in Python onJune 15, 2020

训练曲线

def show_train_history(train_history, train_metrics, validation_metrics):
 plt.plot(train_history.history[train_metrics])
 plt.plot(train_history.history[validation_metrics])
 plt.title('Train History')
 plt.ylabel(train_metrics)
 plt.xlabel('Epoch')
 plt.legend(['train', 'validation'], loc='upper left')

# 显示训练过程
def plot(history):
 plt.figure(figsize=(12, 4))
 plt.subplot(1, 2, 1)
 show_train_history(history, 'acc', 'val_acc')
 plt.subplot(1, 2, 2)
 show_train_history(history, 'loss', 'val_loss')
 plt.show()

效果:

plot(history)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

混淆矩阵

def plot_confusion_matrix(cm, classes,
    title='Confusion matrix',
    cmap=plt.cm.jet):
 cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(classes))
 plt.xticks(tick_marks, classes, rotation=45)
 plt.yticks(tick_marks, classes)
 thresh = cm.max() / 2.
 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
 plt.text(j, i, '{:.2f}'.format(cm[i, j]), horizontalalignment="center",
   color="white" if cm[i, j] > thresh else "black")
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')
 plt.show()

# 显示混淆矩阵
def plot_confuse(model, x_val, y_val):
 predictions = model.predict_classes(x_val)
 truelabel = y_val.argmax(axis=-1) # 将one-hot转化为label
 conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions)
 plt.figure()
 plot_confusion_matrix(conf_mat, range(np.max(truelabel)+1))

其中y_val以one-hot形式输入

效果:

x_val.shape # (25838, 48, 48, 1)
y_val.shape # (25838, 7)
plot_confuse(model, x_val, y_val)

keras训练曲线,混淆矩阵,CNN层输出可视化实例

CNN层输出可视化

# 卷积网络可视化
def visual(model, data, num_layer=1):
 # data:图像array数据
 # layer:第n层的输出
 data = np.expand_dims(data, axis=0) # 开头加一维
 layer = keras.backend.function([model.layers[0].input], [model.layers[num_layer].output])
 f1 = layer([data])[0]
 num = f1.shape[-1]
 plt.figure(figsize=(8, 8))
 for i in range(num):
 plt.subplot(np.ceil(np.sqrt(num)), np.ceil(np.sqrt(num)), i+1)
 plt.imshow(f1[0, :, :, i] * 255, cmap='gray')
 plt.axis('off')
 plt.show()

num_layer : 显示第n层的输出

效果

visual(model, data, 1) # 卷积层
visual(model, data, 2) # 激活层
visual(model, data, 3) # 规范化层
visual(model, data, 4) # 池化层

keras训练曲线,混淆矩阵,CNN层输出可视化实例

补充知识:Python sklearn.cross_validation.train_test_split及混淆矩阵实现

sklearn.cross_validation.train_test_split随机划分训练集和测试集

一般形式:

train_test_split是交叉验证中常用的函数,功能是从样本中随机的按比例选取train data和testdata,形式为:

X_train,X_test, y_train, y_test =
cross_validation.train_test_split(train_data,train_target,test_size=0.4, random_state=0)

参数解释:

train_data:所要划分的样本特征集

train_target:所要划分的样本结果

test_size:样本占比,如果是整数的话就是样本的数量

random_state:是随机数的种子。

随机数种子:其实就是该组随机数的编号,在需要重复试验的时候,保证得到一组一样的随机数。比如你每次都填1,其他参数一样的情况下你得到的随机数组是一样的。但填0或不填,每次都会不一样。随机数的产生取决于种子,随机数和种子之间的关系遵从以下两个规则:种子不同,产生不同的随机数;种子相同,即使实例不同也产生相同的随机数。

示例

fromsklearn.cross_validation import train_test_split
train= loan_data.iloc[0: 55596, :]
test= loan_data.iloc[55596:, :]
# 避免过拟合,采用交叉验证,验证集占训练集20%,固定随机种子(random_state)
train_X,test_X, train_y, test_y = train_test_split(train,
             target,
             test_size = 0.2,
             random_state = 0)
train_y= train_y['label']
test_y= test_y['label']

plot_confusion_matrix.py(混淆矩阵实现实例)

print(__doc__)

import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
 plt.imshow(cm, interpolation='nearest', cmap=cmap)
 plt.title(title)
 plt.colorbar()
 tick_marks = np.arange(len(iris.target_names))
 plt.xticks(tick_marks, iris.target_names, rotation=45)
 plt.yticks(tick_marks, iris.target_names)
 plt.tight_layout()
 plt.ylabel('True label')
 plt.xlabel('Predicted label')

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)
print('Confusion matrix, without normalization')
print(cm)
plt.figure()
plot_confusion_matrix(cm)

# Normalize the confusion matrix by row (i.e by the number of samples
# in each class)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')

plt.show()

以上这篇keras训练曲线,混淆矩阵,CNN层输出可视化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现生命游戏的示例代码(Game of Life)
Jan 24 Python
Python多进程原理与用法分析
Aug 21 Python
对pandas中iloc,loc取数据差别及按条件取值的方法详解
Nov 06 Python
解决pytorch GPU 计算过程中出现内存耗尽的问题
Aug 19 Python
解决python明明pip安装成功却找不到包的问题
Aug 28 Python
通过python扫描二维码/条形码并打印数据
Nov 14 Python
Python list运算操作代码实例解析
Jan 20 Python
Python 为什么推荐蛇形命名法原因浅析
Jun 18 Python
详解anaconda安装步骤
Nov 23 Python
Django多个app urls配置代码实例
Nov 26 Python
linux中nohup和后台运行进程查看及终止
Jun 24 Python
python热力图实现的完整实例
Jun 25 Python
Python3 requests模块如何模仿浏览器及代理
Jun 15 #Python
keras读取训练好的模型参数并把参数赋值给其它模型详解
Jun 15 #Python
keras得到每层的系数方式
Jun 15 #Python
Python类及获取对象属性方法解析
Jun 15 #Python
在Keras中实现保存和加载权重及模型结构
Jun 15 #Python
简单了解Python多态与属性运行原理
Jun 15 #Python
Python类super()及私有属性原理解析
Jun 15 #Python
You might like
php获取textarea的值并处理回车换行的方法
2014/10/20 PHP
JSON用法之将PHP数组转JS数组,JS如何接收PHP数组
2015/10/08 PHP
PHP面向对象五大原则之开放-封闭原则(OCP)详解
2018/04/04 PHP
ThinkPHP防止重复提交表单的方法实例分析
2018/05/10 PHP
Ext.get() 和 Ext.query()组合使用实现最灵活的取元素方式
2011/09/26 Javascript
基于jQuery捕获超链接事件进行局部刷新代码
2012/05/10 Javascript
用客户端js实现带省略号的分页
2013/04/27 Javascript
JavaScript实现找出字符串中第一个不重复的字符
2014/09/03 Javascript
JS实现向表格中动态添加行的方法
2015/03/30 Javascript
JS实现网页每隔3秒弹出一次对话框的方法
2015/11/09 Javascript
jQuery实现图片文字淡入淡出效果
2015/12/21 Javascript
使用postMesssage()实现跨域iframe页面间的信息传递方法
2016/03/29 Javascript
前端面试题及答案整理(二)
2016/08/26 Javascript
Jqprint实现页面打印
2017/01/06 Javascript
创建一般js对象的几种方式
2017/01/19 Javascript
JavaScript 栈的详解及实例代码
2017/01/22 Javascript
vue 计时器组件的实现代码
2017/09/14 Javascript
详解用函数式编程对JavaScript进行断舍离
2017/09/18 Javascript
JS设计模式之策略模式概念与用法分析
2018/02/05 Javascript
详解JS中统计函数执行次数与执行时间
2018/09/04 Javascript
vue.js循环radio的实例
2019/11/07 Javascript
python实现bitmap数据结构详解
2014/02/17 Python
使用Python微信库itchat获得好友和群组已撤回的消息
2018/06/24 Python
关于python tushare Tkinter构建的简单股票可视化查询系统(Beta v0.13)
2020/10/19 Python
python 监控服务器是否有人远程登录(详细思路+代码)
2020/12/18 Python
html5使用canvas实现弹幕功能示例
2017/09/11 HTML / CSS
代码中finally中的代码会不会执行
2012/02/06 面试题
小学生红领巾广播稿
2014/01/21 职场文书
学习党章心得体会2016
2016/01/15 职场文书
某某幼儿园的教育教学管理调研分析报告
2019/11/29 职场文书
Nginx已编译的nginx-添加新模块
2021/04/01 Servers
python中os.path.join()函数实例用法
2021/05/26 Python
MySQL5.7并行复制原理及实现
2021/06/03 MySQL
前端JavaScript大管家 package.json
2021/11/02 Javascript
Vue自定义铃声提示音组件的实现
2022/01/22 Vue.js
SpringBoot2零基础到精通之数据库专项精讲
2022/03/22 Java/Android