pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Linux下编译安装MySQL-Python教程
Feb 02 Python
简单介绍Python的Django框架加载模版的方式
Jul 20 Python
Puppeteer使用示例详解
Jun 20 Python
python3模拟实现xshell远程执行liunx命令的方法
Jul 12 Python
Pandas操作CSV文件的读写实现方法
Nov 13 Python
python设置代理和添加镜像源的方法
Feb 14 Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 Python
Keras中的多分类损失函数用法categorical_crossentropy
Jun 11 Python
基于python判断字符串括号是否闭合{}[]()
Sep 21 Python
python小程序之飘落的银杏
Apr 17 Python
python中的装饰器该如何使用
Jun 18 Python
Python 中的 copy()和deepcopy()
Nov 07 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
建立动态的WML站点(二)
2006/10/09 PHP
window+nginx+php环境配置 附配置搭配说明
2010/12/29 PHP
PHP内核探索:变量概述
2014/01/30 PHP
php动态函数调用方法
2015/05/21 PHP
PHP curl使用实例
2015/07/02 PHP
Android AsyncTack 异步任务实例详解
2016/11/02 PHP
PHP搭建大文件切割分块上传功能示例
2017/01/04 PHP
PHP使用finfo_file()函数检测上传图片类型的实现方法
2017/04/18 PHP
PHP html_entity_decode()函数讲解
2019/02/25 PHP
通过身份证号得到出生日期和性别的js代码
2009/11/23 Javascript
Javascript中判断变量是数组还是对象(array还是object)
2013/08/14 Javascript
JavaScript onkeydown事件入门实例(键盘某个按键被按下)
2014/10/17 Javascript
轻松创建nodejs服务器(3):代码模块化
2014/12/18 NodeJs
JavaScript中用于四舍五入的Math.round()方法讲解
2015/06/15 Javascript
jQuery实现的fixedMenu下拉菜单效果代码
2015/08/24 Javascript
JavaScript判断手机号运营商是移动、联通、电信还是其他(代码简单)
2015/09/25 Javascript
Jquery中巧用Ajax的beforeSend方法
2016/01/20 Javascript
JavaScript必知必会(十) call apply bind的用法说明
2016/06/08 Javascript
jQuery实现可展开折叠的导航效果示例
2016/09/12 Javascript
javascript之with的使用(阿里云、淘宝使用代码分析)
2016/10/11 Javascript
使用JavaScript判断用户输入的是否为正整数(两种方法)
2017/02/05 Javascript
微信通过页面(H5)直接打开本地app的解决方法
2017/09/09 Javascript
Vue插件之滑动验证码
2019/09/21 Javascript
python 运算符 供重载参考
2009/06/11 Python
Python利用matplotlib生成图片背景及图例透明的效果
2017/04/27 Python
在Python中使用AOP实现Redis缓存示例
2017/07/11 Python
Python中的默认参数实例分析
2018/01/29 Python
Django1.9 加载通过ImageField上传的图片方法
2018/05/25 Python
在unittest中使用 logging 模块记录测试数据的方法
2018/11/30 Python
Python完成哈夫曼树编码过程及原理详解
2019/07/29 Python
python常用排序算法的实现代码
2019/11/08 Python
1688平价精选商城:阿里集团旗下,工厂出厂价格直销
2017/04/24 全球购物
中学生操行评语
2014/04/24 职场文书
家具商场的活动方案
2014/08/16 职场文书
我的中国梦演讲稿小学篇
2014/08/19 职场文书
pytorch 实现在测试的时候启用dropout
2021/05/27 Python