pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过pil模块获得图片exif信息的方法
Mar 16 Python
使用Python中的线程进行网络编程的入门教程
Apr 15 Python
Python面向对象之类和对象属性的增删改查操作示例
Dec 14 Python
利用python如何在前程无忧高效投递简历
May 07 Python
Python流程控制 if else实现解析
Sep 02 Python
python 定义类时,实现内部方法的互相调用
Dec 25 Python
浅谈Pycharm最有必要改的几个默认设置项
Feb 14 Python
python读取当前目录下的CSV文件数据
Mar 11 Python
详解Python3中的 input() 函数
Mar 18 Python
对python pandas中 inplace 参数的理解
Jun 27 Python
Python 列表推导式需要注意的地方
Oct 23 Python
python实现企业微信定时发送文本消息的实例代码
Nov 25 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
浅析php数据类型转换
2014/01/09 PHP
注意!PHP 7中不要做的10件事
2016/09/18 PHP
PHP实现二叉树深度优先遍历(前序、中序、后序)和广度优先遍历(层次)实例详解
2018/04/20 PHP
基于PHP实现短信验证码发送次数限制
2020/07/11 PHP
Jquery 实现Tab效果 思路是js思路
2010/03/02 Javascript
javascript数组去重3种方法的性能测试与比较
2013/03/26 Javascript
今天是星期几的4种JS代码写法
2013/09/17 Javascript
JQuery文本改变触发事件如聚焦事件、失焦事件
2014/01/15 Javascript
javascript实现鼠标移到Image上方时显示文字效果的方法
2015/08/07 Javascript
jQuery监听文件上传实现进度条效果的方法
2016/10/16 Javascript
原生js和css实现图片轮播效果
2017/02/07 Javascript
微信小程序 引用其他js文件实现代码
2017/02/22 Javascript
从零学习node.js之详解异步控制工具async(八)
2017/02/27 Javascript
完美解决浏览器跨域的几种方法(汇总)
2017/05/08 Javascript
jQuery实现用户信息表格的添加和删除功能
2017/09/12 jQuery
Vue项目webpack打包部署到Tomcat刷新报404错误问题的解决方案
2018/05/15 Javascript
原来JS还可以这样拆箱转换详解
2019/02/01 Javascript
[49:18]2018DOTA2亚洲邀请赛 3.31 小组赛 A组 OG vs TNC
2018/04/01 DOTA
Python中Random和Math模块学习笔记
2015/05/18 Python
python使用wxpython开发简单记事本的方法
2015/05/20 Python
Python文件读取的3种方法及路径转义
2015/06/21 Python
asyncio 的 coroutine对象 与 Future对象使用指南
2016/09/11 Python
python 获取当天凌晨零点的时间戳方法
2018/05/22 Python
python 图像的离散傅立叶变换实例
2020/01/02 Python
python实现滑雪游戏
2020/02/22 Python
python支持多继承吗
2020/06/19 Python
Python join()函数原理及使用方法
2020/11/14 Python
Selenium环境变量配置(火狐浏览器)及验证实现
2020/12/07 Python
HTML5的Geolocation地理位置定位API使用教程
2016/05/12 HTML / CSS
蔻驰美国官网:COACH美国
2016/08/18 全球购物
Gibson London官网:以地道的英国男装而著称
2019/12/06 全球购物
分公司任命书
2014/06/06 职场文书
茶花女读书笔记
2015/06/29 职场文书
Python爬虫:从m3u8文件里提取小视频的正确操作
2021/05/14 Python
MySQL提升大量数据查询效率的优化神器
2022/07/07 MySQL
table设置超出部分隐藏,鼠标移上去显示全部内容的方法
2022/12/24 HTML / CSS