pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现过滤单个Android程序日志脚本分享
Jan 16 Python
python开发之基于thread线程搜索本地文件的方法
Nov 11 Python
python文件操作相关知识点总结整理
Feb 22 Python
Python操作json的方法实例分析
Dec 06 Python
详解pyinstaller selenium python3 chrome打包问题
Oct 18 Python
python实现二分类的卡方分箱示例
Nov 22 Python
Tensorflow分批量读取数据教程
Feb 07 Python
Python类中的装饰器在当前类中的声明与调用详解
Apr 15 Python
Python Tornado核心及相关原理详解
Jun 24 Python
python闭包与引用以及需要注意的陷阱
Sep 18 Python
如何通过python计算圆周率PI
Nov 11 Python
python中numpy数组与list相互转换实例方法
Jan 29 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
PHP令牌 Token改进版
2008/07/18 PHP
PHP中使用gettext来支持多语言的方法
2011/05/02 PHP
php json_encode()函数返回json数据实例代码
2014/10/10 PHP
php查找指定目录下指定大小文件的方法
2014/11/28 PHP
使用php转义输出HTML到JavaScript
2015/03/27 PHP
typecho插件编写教程(五):核心代码
2015/05/28 PHP
php微信开发之关注事件
2018/06/14 PHP
php文件后缀不强制为.php的实操方法
2019/09/18 PHP
document.getElementBy("id")与$("#id")有什么区别
2013/09/22 Javascript
阻止事件(取消浏览器对事件的默认行为并阻止其传播)
2013/11/03 Javascript
javascript的渐进增强与平稳退化浅谈
2013/11/12 Javascript
动态读取JSON解析键值对的方法
2014/06/03 Javascript
jQuery移动页面开发中的触摸事件与虚拟鼠标事件简介
2015/12/03 Javascript
JavaScript事件 "事件对象"的注意要点
2016/01/14 Javascript
jQuery实现动态添加tr到table的方法
2016/12/26 Javascript
基于vue2.0实现的级联选择器
2017/06/09 Javascript
微信小程序canvas写字板效果及实例
2017/06/15 Javascript
jQuery实现动态添加节点与遍历节点功能示例
2017/11/09 jQuery
Nodejs中的JWT和Session的使用
2018/08/21 NodeJs
[00:12]DAC2018 no[o]ne亮相SOLO赛 他是否如他的id一样无人可挡?
2018/04/06 DOTA
[39:53]完美世界DOTA2联赛PWL S2 LBZS vs Forest 第一场 11.19
2020/11/19 DOTA
Python中exit、return、sys.exit()等使用实例和区别
2015/05/28 Python
pyenv命令管理多个Python版本
2017/03/26 Python
Python调用ctypes使用C函数printf的方法
2017/08/23 Python
python绘制评估优化算法性能的测试函数
2019/06/25 Python
Spring实战之使用util:命名空间简化配置操作示例
2019/12/09 Python
python音频处理的示例详解
2020/12/23 Python
HTML5上传文件显示进度的实现代码
2012/08/30 HTML / CSS
Speedo美国:澳大利亚顶尖泳衣制造商
2016/08/03 全球购物
世界上最大的专业美容用品零售商:Sally Beauty
2017/07/02 全球购物
Reebok官方旗舰店:美国知名健身品牌锐步
2019/01/07 全球购物
什么是反射
2012/03/17 面试题
营销人才自我鉴定范文
2013/12/25 职场文书
不假外出检讨书
2014/01/27 职场文书
保险专业大学生职业规划书
2014/03/03 职场文书
网上祭先烈心得体会
2014/09/01 职场文书