pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现从脚本里运行scrapy的方法
Apr 07 Python
Python 爬虫爬取指定博客的所有文章
Feb 17 Python
用yum安装MySQLdb模块的步骤方法
Dec 15 Python
Python基础教程之异常详解
Jan 10 Python
Python基础知识点 初识Python.md
May 14 Python
在python中用print()输出多个格式化参数的方法
Jul 16 Python
使用django的objects.filter()方法匹配多个关键字的方法
Jul 18 Python
python反转列表的三种方式解析
Nov 08 Python
PyTorch中反卷积的用法详解
Dec 30 Python
python新手学习可变和不可变对象
Jun 11 Python
Python实现简单猜数字游戏
Feb 03 Python
基于PyQt5制作一个群发邮件工具
Apr 08 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
如何在PHP中使用Oracle数据库(5)
2006/10/09 PHP
php 破解防盗链图片函数
2008/12/09 PHP
php下pdo的mysql事务处理用法实例
2014/12/27 PHP
深入理解PHP内核(一)
2015/11/10 PHP
PHP将二维数组某一个字段相同的数组合并起来的方法
2016/02/26 PHP
php版微信支付api.mch.weixin.qq.com域名解析慢原因与解决方法
2016/10/12 PHP
详解PHP中的外观模式facade pattern
2018/02/05 PHP
PHP微信H5支付开发实例
2018/07/25 PHP
Ajax+PHP实现的删除数据功能示例
2019/02/12 PHP
javascript中IE浏览器不支持NEW DATE()带参数的解决方法
2012/03/01 Javascript
实用的JS正则表达式(手机号码/IP正则/邮编正则/电话等)
2013/01/11 Javascript
jquery 延迟执行实例介绍
2013/08/20 Javascript
Js中的onblur和onfocus事件应用介绍
2013/08/27 Javascript
javascript简单性能问题及学习笔记
2014/02/04 Javascript
用js的document.write输出的广告无阻塞加载的方法
2014/06/05 Javascript
jQuery移动页面开发中的触摸事件与虚拟鼠标事件简介
2015/12/03 Javascript
JavaScript弹出对话框的三种方式
2016/03/23 Javascript
使用js获取地址栏参数的方法推荐(超级简单)
2016/06/14 Javascript
js装饰设计模式学习心得
2018/02/17 Javascript
原生JS+HTML5实现跟随鼠标一起流动的粒子动画效果
2018/05/03 Javascript
详解Vue+ElementUI从零开始搭建自己的网站(一、环境搭建)
2019/04/30 Javascript
js调用网络摄像头的方法
2020/12/05 Javascript
[03:59]DOTA2英雄梦之声_第07期_水晶室女
2014/06/23 DOTA
[04:46]2018年度玩家喜爱的电竞媒体-完美盛典
2018/12/16 DOTA
Python的Django REST框架中的序列化及请求和返回
2016/04/11 Python
NumPy.npy与pandas DataFrame的实例讲解
2018/07/09 Python
对python判断ip是否可达的实例详解
2019/01/31 Python
python实现数字炸弹游戏程序
2020/07/17 Python
解决python和pycharm安装gmpy2 出现ERROR的问题
2020/08/28 Python
基于Python-Pycharm实现的猴子摘桃小游戏(源代码)
2021/02/20 Python
Structs界面控制层技术
2013/10/11 面试题
审计主管岗位职责
2014/01/31 职场文书
应届生求职信范文
2014/05/26 职场文书
农村优秀教师事迹材料
2014/08/27 职场文书
会议欢迎词
2015/01/23 职场文书
学习习近平主席讲话心得体会
2016/01/20 职场文书