pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的HTTP并发测试完整示例
Apr 23 Python
Python subprocess模块常见用法分析
Jun 12 Python
Python2包含中文报错的解决方法
Jul 09 Python
Python中使用logging和traceback模块记录日志和跟踪异常
Apr 09 Python
Python文件时间操作步骤代码详解
Apr 13 Python
解决Jupyter NoteBook输出的图表太小看不清问题
Apr 16 Python
python之pygame模块实现飞机大战完整代码
Nov 29 Python
Python读写Excel表格的方法
Mar 02 Python
python实现黄金分割法的示例代码
Apr 28 Python
Python如何配置环境变量详解
May 18 Python
Python趣味实战之手把手教你实现举牌小人生成器
Jun 07 Python
Python使用MapReduce进行简单的销售统计
Apr 22 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
PHP如何编写易读的代码
2007/07/10 PHP
提高PHP编程效率 引入缓存机制提升性能
2010/02/15 PHP
PHP 文本文章分页代码 按标记或长度(不涉及数据库)
2012/06/07 PHP
php编写的抽奖程序中奖概率算法
2015/05/14 PHP
Input 特殊事件onpopertychange和oninput
2009/06/17 Javascript
JavaScript 学习笔记 Black.Caffeine 09.11.28
2009/11/30 Javascript
有趣的javascript数组定义方法
2010/09/10 Javascript
jQuery遍历Form示例代码
2013/09/03 Javascript
js跳转页面方法实现汇总
2014/02/11 Javascript
用js判断输入是否为中文的函数
2014/03/10 Javascript
浅析jQuery Ajax通用js封装
2016/06/22 Javascript
js按条件生成随机json:randomjson实现方法
2017/04/07 Javascript
Angular中的interceptors拦截器
2017/06/25 Javascript
vue2实现可复用的轮播图carousel组件详解
2017/11/27 Javascript
详解JavaScript基础知识(JSON、Function对象、原型、引用类型)
2018/01/16 Javascript
ajax请求+vue.js渲染+页面加载的示例
2018/02/11 Javascript
详解js跨域请求的两种方式,支持post请求
2018/05/05 Javascript
JavaScript new对象的四个过程实例浅析
2018/07/31 Javascript
微信小程序实现选项卡效果
2018/11/06 Javascript
简谈创建React Component的几种方式
2019/06/15 Javascript
vue实现短信验证码输入框
2020/04/17 Javascript
实现一个Vue自定义指令懒加载的方法示例
2020/06/04 Javascript
微信小程序自定义胶囊样式
2020/12/27 Javascript
Python的Tornado框架的异步任务与AsyncHTTPClient
2016/06/27 Python
基于Django contrib Comments 评论模块(详解)
2017/12/08 Python
Python基础教程之异常详解
2019/01/10 Python
Python实战之制作天气查询软件
2019/05/14 Python
python 实现图像快速替换某种颜色
2020/06/04 Python
pandas之分组groupby()的使用整理与总结
2020/06/18 Python
基于Python爬取51cto博客页面信息过程解析
2020/08/25 Python
python快速安装OpenCV的步骤记录
2021/02/22 Python
10分钟理解CSS3 FlexBox弹性布局
2018/12/20 HTML / CSS
2014年销售部工作总结
2014/12/01 职场文书
2015年初中教师个人工作总结
2015/07/21 职场文书
Mysql systemctl start mysqld报错的问题解决
2021/06/03 MySQL
Python爬虫入门案例之爬取去哪儿旅游景点攻略以及可视化分析
2021/10/16 Python