pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python将人民币转换大写的脚本代码
Feb 10 Python
Python实现的简单文件传输服务器和客户端
Apr 08 Python
python使用in操作符时元组和数组的区别分析
May 19 Python
使用Python来开发Markdown脚本扩展的实例分享
Mar 04 Python
基于Django与ajax之间的json传输方法
May 29 Python
python脚本当作Linux中的服务启动实现方法
Jun 28 Python
阿里云ECS服务器部署django的方法
Aug 29 Python
python并发爬虫实用工具tomorrow实用解析
Sep 25 Python
Python3标准库之threading进程中管理并发操作方法
Mar 30 Python
基于python爬取有道翻译过程图解
Mar 31 Python
Python和Bash结合在一起的方法
Nov 13 Python
pytorch常用数据类型所占字节数对照表一览
May 17 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
php socket通信(tcp/udp)实例分析
2016/02/14 PHP
DIY jquery plugin - tabs标签切换实现代码
2010/12/11 Javascript
jquery 插件学习(二)
2012/08/06 Javascript
使用简洁的jQuery方法实现隔行换色功能
2014/01/02 Javascript
javascript计算当月剩余天数(天数计算器)示例代码
2014/01/09 Javascript
jQuery中map()方法用法实例
2015/01/06 Javascript
javascript字符串与数组转换汇总
2015/05/26 Javascript
JS版元素周期表实现方法
2015/08/05 Javascript
使用ES6语法重构React代码详解
2017/05/09 Javascript
js使用xml数据载体实现城市省份二级联动效果
2017/11/08 Javascript
angular2组件中定时刷新并清除定时器的实例讲解
2018/08/31 Javascript
实用Javascript调试技巧分享(小结)
2019/06/18 Javascript
jQuery实现文本显示一段时间后隐藏的方法分析
2019/06/20 jQuery
JavaScript跳出循环的三种方法(break, return, continue)
2019/07/30 Javascript
Vue使用vue-recoure + http-proxy-middleware + vuex配合promise实现基本的跨域请求封装
2019/10/21 Javascript
js判断一个对象是数组(函数)的方法实例
2019/12/19 Javascript
vue3.0搭配.net core实现文件上传组件
2020/10/29 Javascript
[43:47]完美世界DOTA2联赛PWL S3 LBZS vs Phoenix 第一场 12.09
2020/12/11 DOTA
python getopt详解及简单实例
2016/12/30 Python
python pandas 如何替换某列的一个值
2018/06/09 Python
使用Python微信库itchat获得好友和群组已撤回的消息
2018/06/24 Python
解决nohup执行python程序log文件写入不及时的问题
2019/01/14 Python
python小程序实现刷票功能详解
2019/07/17 Python
Django 开发调试工具 Django-debug-toolbar使用详解
2019/07/23 Python
Python中生成ndarray实例讲解
2021/02/22 Python
HearthSong官网:儿童户外玩具、儿童益智玩具
2017/10/16 全球购物
计算机专业个人求职信范例
2013/09/23 职场文书
临床医学专业学生的自我评价分享
2013/11/21 职场文书
赔偿协议书范本
2014/04/15 职场文书
群众路线专项整治工作情况报告
2014/10/28 职场文书
先进教育工作者事迹材料
2014/12/23 职场文书
个人委托书范文
2015/01/28 职场文书
惹女朋友生气检讨书
2015/05/06 职场文书
我是特种兵观后感
2015/06/11 职场文书
python学习之panda数据分析核心支持库
2021/05/07 Python
JS的深浅复制详细
2021/10/16 Javascript