pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础入门详解(文件输入/输出 内建类型 字典操作使用方法)
Dec 08 Python
Python实现批量把SVG格式转成png、pdf格式的代码分享
Aug 21 Python
python静态方法实例
Jan 14 Python
python开发环境PyScripter中文乱码问题解决方案
Sep 11 Python
教你用Python脚本快速为iOS10生成图标和截屏
Sep 22 Python
Python基于win32ui模块创建弹出式菜单示例
May 09 Python
Python3连接SQLServer、Oracle、MySql的方法
Jun 28 Python
python是否适合网页编程详解
Oct 04 Python
pytorch 实现tensor与numpy数组转换
Dec 27 Python
Python小整数对象池和字符串intern实例解析
Mar 21 Python
详细分析Python collections工具库
Jul 16 Python
对Keras自带Loss Function的深入研究
May 25 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
PHP分页显示制作详细讲解
2006/12/05 PHP
php strtotime 函数UNIX时间戳
2009/01/14 PHP
PHP显示今天、今月、上月、今年的起点/终点时间戳的代码
2011/05/25 PHP
10条PHP高级技巧[修正版]
2011/08/02 PHP
PHP安全防范技巧分享
2011/11/03 PHP
初窥JQuery-Jquery简介 入门了解篇
2010/11/25 Javascript
javascript处理table表格的代码
2010/12/06 Javascript
javascript ie6兼容position:fixed实现思路
2013/04/01 Javascript
JavaScript包装对象使用介绍
2013/08/29 Javascript
node.js中的buffer.write方法使用说明
2014/12/10 Javascript
jQuery Validate初步体验(一)
2015/12/12 Javascript
js获取Html元素的实际宽度高度的方法
2016/05/19 Javascript
微信页面倒计时代码(解决safari不兼容date的问题)
2016/12/13 Javascript
jQuery实现的checkbox级联选择下拉菜单效果示例
2016/12/26 Javascript
浅谈js script标签中的预解析
2016/12/30 Javascript
详解JS: reduce方法实现 webpack多文件入口
2017/02/14 Javascript
Koa2 之文件上传下载的示例代码
2018/03/29 Javascript
小程序从手动埋点到自动埋点的实现方法
2019/01/24 Javascript
jquery分页优化操作实例分析
2019/08/23 jQuery
vue项目中监听手机物理返回键的实现
2020/01/18 Javascript
nodejs各种姿势断点调试的方法
2020/06/18 NodeJs
一文秒懂JavaScript构造函数、实例、原型对象以及原型链
2020/08/25 Javascript
微信小程序使用前置摄像头拍照
2020/10/22 Javascript
[01:35]辉夜杯战队访谈宣传片—LGD
2015/12/25 DOTA
[45:16]完美世界DOTA2联赛PWL S3 Magma vs Phoenix 第一场 12.12
2020/12/16 DOTA
Python中optionParser模块的使用方法实例教程
2014/08/29 Python
Python获取Linux系统下的本机IP地址代码分享
2014/11/07 Python
Python实现进程同步和通信的方法
2018/01/02 Python
Python解决抛小球问题 求小球下落经历的距离之和示例
2018/02/01 Python
python3实现指定目录下文件sha256及文件大小统计
2019/02/25 Python
详解python解压压缩包的五种方法
2019/07/05 Python
Python实现微信机器人的方法
2019/09/06 Python
韩国保养品、日本药妆购物网:小三美日
2018/12/30 全球购物
植树节标语
2014/06/27 职场文书
船舶工程技术专业求职信
2014/08/07 职场文书
家具商场的活动方案
2014/08/16 职场文书