使用Tensorflow实现可视化中间层和卷积层


Posted in Python onJanuary 24, 2020

为了查看网络训练的效果或者便于调参、更改结构等,我们常常将训练网络过程中的loss、accurcy等参数。

除此之外,有时我们也想要查看训练好的网络中间层输出和卷积核上面表达了什么内容,这可以帮助我们思考CNN的内在机制、调整网络结构或者把这些可视化内容贴在论文当中辅助说明训练的效果等。

中间层和卷积核的可视化有多种方法,整理如下:

1. 以矩阵(matrix)格式手动输出图像:

用简单的LeNet网络训练MNIST数据集作为示例:

x = tf.placeholder(tf.float32, [None, 784]) 

x_image = tf.reshape(x, [-1,28,28,1])    
W_conv1 = weight_variable([5, 5, 1, 32]) # 第一个卷积层的32个卷积核  
b_conv1 = bias_variable([32])  
# 第一个卷积层:  
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool(h_conv1)  # 第一个池化层

训练结束后,第一个卷积层共有32个5*5大小的卷积核:W_conv1,要可视化第10个卷积核:

from PIL import Image
import numpy as np
#from mnist_try001 import W_conv1

img1 = (W_conv1.eval()) # 将张量转换为numpy数组
W_conv1_10 = img1[:,:,:,9] 

W_conv1_10 = np.asmatrix(W_conv1_10) # 将数组转换为矩阵格式
W_conv1_10_visual = Image.fromarray(W_conv1_10 * 255.0 / W_conv1_10.max()) # 像素值归一化,Image.fromarray方法的输入范围是[0~255]
W_conv1_10_visual.show()

2. 通过反卷积方式输出中间层和卷积核图像:

import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

x = tf.placeholder(tf.float32, [None, 784])
mnist = input_data.read_data_sets('/TensorflowCode/MNIST_data', one_hot=True)

h_conv2 = tf.nn.relu(tf.nn.conv2d(h_pool1, W_conv2, strides=[1, 1, 1, 1], padding='SAME') + b_conv2) #14*14*64
# 可视化第二层输出的图像
input_image = mnist.train.images[100] # 输入一幅指定图像,mnist.train.images[100]尺寸为[784,],即1维:[1,784]
conv2 = sess.run(h_conv2, feed_dict={x:input_image}) # [64, 14, 14 ,1] 若前面网络中加入了dropout,这里的feed_dict中不要忘记加上keep_prob: 0.5
conv2 = sess.run(tf.reshape(conv2 , [64, 1, 14, 14]))
conv2 = np.sum(conv2,axis = 0) # 对中间层图像各通道求和,作为输出图像
h_conv1 = np.asmatrix(h_conv1) # 将conv2数组转换成矩阵格式
h_conv1 = Image.fromarray(h_conv1 * 255.0 / h_conv1.max()) # 矩阵数值归一化
h_conv1.show() # 输出14*14的灰度图像

可视化卷积核和上面的方法完全一样,把h_conv2改成卷积核就可以了(如W_conv1_10),可以同是输出多个卷积核。

中间层图像如下:(已经完全看不出是数字了)

使用Tensorflow实现可视化中间层和卷积层

或者用 matplotlib.pyplot代替上面的Image方法,可以直接输出彩色图像:

# 输出第一层的32个卷积核(5×5*32)
import matplotlib.pyplot as plt

input_image = mnist.train.images[100]
W_conv1 = sess.run(W_conv1, feed_dict={x:input_image})   
W_conv1 = sess.run(tf.reshape(conv1_16, [32, 1, 5, 5]))
fig1,ax1 = plt.subplots(nrows=1, ncols=32, figsize = (32,1))
for i in range(32):
  ax1[i].imshow( W_conv1[i][0])           
plt.title('W_conv1 32×5×5')
plt.show()

以上这篇使用Tensorflow实现可视化中间层和卷积层就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中合并两个文本文件并按照姓名首字母排序的例子
Apr 25 Python
Python入门之三角函数全解【收藏】
Nov 08 Python
通过python+selenium3实现浏览器刷简书文章阅读量
Dec 26 Python
python使用生成器实现可迭代对象
Mar 20 Python
python3 pandas 读取MySQL数据和插入的实例
Apr 20 Python
对TensorFlow中的variables_to_restore函数详解
Jul 30 Python
Python中logging实例讲解
Jan 17 Python
几个适合python初学者的简单小程序,看完受益匪浅!(推荐)
Apr 16 Python
python语言元素知识点详解
May 15 Python
Python小白学习爬虫常用请求报头
Jun 03 Python
python使用建议与技巧分享(一)
Aug 17 Python
python turtle绘图命令及案例
Nov 23 Python
tensorflow获取预训练模型某层参数并赋值到当前网络指定层方式
Jan 24 #Python
keras获得某一层或者某层权重的输出实例
Jan 24 #Python
浅谈keras的深度模型训练过程及结果记录方式
Jan 24 #Python
关于Keras模型可视化教程及关键问题的解决
Jan 24 #Python
基于keras 模型、结构、权重保存的实现
Jan 24 #Python
Python 文件数据读写的具体实现
Jan 24 #Python
利用keras加载训练好的.H5文件,并实现预测图片
Jan 24 #Python
You might like
PHP 超链接 抓取实现代码
2009/06/29 PHP
PHP获取用户的浏览器与操作系统信息的代码
2012/09/04 PHP
php批量删除数据库下指定前缀的表以prefix_为例
2014/08/24 PHP
用函数式编程技术编写优美的 JavaScript_ibm
2008/05/16 Javascript
番茄的表单验证类代码修改版
2008/07/18 Javascript
默认让页面的第一个控件选中的javascript代码
2009/12/26 Javascript
js获取单元格自定义属性值的代码(IE/Firefox)
2010/04/05 Javascript
Javascript 判断是否存在函数的方法
2013/01/03 Javascript
JavaScript lastIndexOf方法入门实例(计算指定字符在字符串中最后一次出现的位置)
2014/10/17 Javascript
JS实现的另类手风琴效果网页内容切换代码
2015/09/08 Javascript
jquery 实现回车登录详解及实例代码
2016/10/23 Javascript
jquery文字填写自动高度的实现方法
2016/11/07 Javascript
jQuery对table表格进行增删改查
2020/12/22 Javascript
折叠菜单及选择器的运用
2017/02/03 Javascript
详解Vue+axios+Node+express实现文件上传(用户头像上传)
2018/08/10 Javascript
nodeJS进程管理器pm2的使用
2019/01/09 NodeJs
使用vue-cli脚手架工具搭建vue-webpack项目
2019/01/14 Javascript
[01:04:01]2014 DOTA2国际邀请赛中国区预选赛 5 23 CIS VS DT第一场
2014/05/24 DOTA
浅谈scrapy 的基本命令介绍
2017/06/13 Python
python批量获取html内body内容的实例
2019/01/02 Python
python中多个装饰器的调用顺序详解
2019/07/16 Python
Python队列RabbitMQ 使用方法实例记录
2019/08/05 Python
python实现对图片进行旋转,放缩,裁剪的功能
2019/08/07 Python
python语言线程标准库threading.local解读总结
2019/11/10 Python
基于python实现查询ip地址来源
2020/06/02 Python
CSS3 简单又实用的5个属性
2010/03/04 HTML / CSS
如何使用canvas绘制可移动网格的示例代码
2020/12/14 HTML / CSS
Made in Design德国:设计师家具、灯具和装饰
2019/10/31 全球购物
竞聘医务工作人员的自我评价分享
2013/11/04 职场文书
个人实习生的自我评价
2014/02/16 职场文书
银行服务感言
2014/03/01 职场文书
男女朋友协议书
2014/04/23 职场文书
详解GaussDB for MySQL性能优化
2021/05/18 MySQL
浅谈如何提高PHP代码质量之端到端集成测试
2021/05/28 PHP
JavaScript 定时器详情
2021/11/11 Javascript
Python+Selenium实现抖音、快手、B站、小红书、微视、百度好看视频、西瓜视频、微信视频号、搜狐视频、一点号、大风号、趣头条等短视频自动发布
2022/04/13 Python