pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现封装得到virustotal扫描结果
Oct 05 Python
Python中实现从目录中过滤出指定文件类型的文件
Feb 02 Python
使用Python标准库中的wave模块绘制乐谱的简单教程
Mar 30 Python
Python里disconnect UDP套接字的方法
Apr 23 Python
asyncio 的 coroutine对象 与 Future对象使用指南
Sep 11 Python
pytorch 把MNIST数据集转换成图片和txt的方法
May 20 Python
python使用tornado实现简单爬虫
Jul 28 Python
python 保存float类型的小数的位数方法
Oct 17 Python
Python 元组操作总结
Sep 18 Python
python中提高pip install速度
Feb 14 Python
python 逐步回归算法
Apr 06 Python
python 使用tkinter与messagebox写界面和弹窗
Mar 20 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
Smarty模板快速入门
2007/01/04 PHP
PHP判断文件是否存在、是否可读、目录是否存在的代码
2012/10/03 PHP
ThinkPHP中使用ajax接收json数据的方法
2014/12/18 PHP
YII2框架中日志的配置与使用方法实例分析
2020/03/18 PHP
javascript 主动派发事件总结
2011/08/09 Javascript
jquery中的on方法使用介绍
2013/12/29 Javascript
Javascript动态引用CSS文件的2种方法介绍
2014/06/06 Javascript
javascript正则表达式参数/g与/i及/gi的使用指南
2014/08/27 Javascript
详解JavaScript的变量和数据类型
2015/11/27 Javascript
JavaScript面向对象程序设计教程
2016/03/29 Javascript
JS实现的表格行上下移动操作示例
2016/08/03 Javascript
javascript计算渐变颜色的实例
2017/09/22 Javascript
详解微信小程序Page中data数据操作和函数调用
2017/09/27 Javascript
通过说明与示例了解js五种设计模式
2019/06/17 Javascript
vue.js 打包时出现空白页和路径错误问题及解决方法
2019/06/26 Javascript
VUE子组件向父组件传值详解(含传多值及添加额外参数场景)
2020/09/01 Javascript
多个Vue项目部署到服务器的步骤记录
2020/10/22 Javascript
下载给定网页上图片的方法
2014/02/18 Python
详解Python的迭代器、生成器以及相关的itertools包
2015/04/02 Python
使用Python实现在Windows下安装Django
2018/10/17 Python
Python基于Logistic回归建模计算某银行在降低贷款拖欠率的数据示例
2019/01/23 Python
Python实现简单查找最长子串功能示例
2019/02/26 Python
python3+PyQt5 实现Rich文本的行编辑方法
2019/06/17 Python
python实现在线翻译功能
2020/03/03 Python
XD健身器材:Kevlar球、Crossfit健身球
2019/03/26 全球购物
湖南卫视在线视频媒体平台:芒果TV
2019/10/30 全球购物
求职信范文怎么写
2014/01/29 职场文书
个人党性剖析材料
2014/02/03 职场文书
《黄山奇石》教学反思
2014/04/19 职场文书
升旗仪式演讲稿
2014/05/08 职场文书
环保宣传标语
2014/06/12 职场文书
开展创先争优活动总结
2014/08/28 职场文书
合作合同协议书范本
2015/01/27 职场文书
仙境之桥观后感
2015/06/16 职场文书
大学生活感想
2015/08/10 职场文书
大学生社会服务心得体会
2016/01/22 职场文书