pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python开发之thread线程基础实例入门
Nov 11 Python
Python面向对象特殊成员
Apr 24 Python
Python中property函数用法实例分析
Jun 04 Python
python 图像平移和旋转的实例
Jan 10 Python
python scrapy爬虫代码及填坑
Aug 12 Python
Series和DataFrame使用简单入门
Nov 13 Python
详解centos7+django+python3+mysql+阿里云部署项目全流程
Nov 15 Python
如何使用Python脚本实现文件拷贝
Nov 20 Python
python Opencv计算图像相似度过程解析
Dec 03 Python
python字符串下标与切片及使用方法
Feb 13 Python
详解用Python爬虫获取百度企业信用中企业基本信息
Jul 02 Python
浅谈matplotlib默认字体设置探索
Feb 03 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
PHP 服务器配置(使用Apache及IIS两种方法)
2009/06/01 PHP
php格式输出文件var_export函数实例
2014/11/15 PHP
php导出csv文件,可导出前导0实例代码
2016/11/16 PHP
javascript 获取元素位置的快速方法 getBoundingClientRect()
2009/11/26 Javascript
js 文件引入实现代码
2010/04/23 Javascript
javascript 闭包
2011/09/15 Javascript
window.parent与window.openner区别介绍
2012/04/12 Javascript
js浮点数保留两位小数点示例代码(四舍五入)
2013/12/26 Javascript
JavaScript利用构造函数和原型的方式模拟C#类的功能
2014/03/06 Javascript
js实现的类似于asp数据字典的数据类型代码实例
2014/09/03 Javascript
JavaScript中数据结构与算法(三):链表
2015/06/19 Javascript
JavaScript获得url查询参数的方法
2015/07/02 Javascript
JavaScript 封装一个tab效果源码分享
2015/09/15 Javascript
javascript设计简单的秒表计时器
2020/09/05 Javascript
Angular2.0/4.0 使用Echarts图表的示例代码
2017/12/07 Javascript
Express下采用bcryptjs进行密码加密的方法
2018/02/07 Javascript
在vue里面设置全局变量或数据的方法
2018/03/09 Javascript
iview中Select 选择器多选校验方法
2018/03/15 Javascript
对angular4子路由&辅助路由详解
2018/10/09 Javascript
探索JavaScript中私有成员的相关知识
2019/06/13 Javascript
JS开发自己的类库实例分析
2019/08/28 Javascript
微信小程序自定义tabBar在uni-app的适配详解
2019/09/30 Javascript
如何基于js判断浏览器版本
2020/02/20 Javascript
[15:57]教你分分钟做大人:斧王
2014/10/30 DOTA
[57:24]LGD vs VGJ.T 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
跟老齐学Python之总结参数的传递
2014/10/10 Python
Python简单的制作图片验证码实例
2017/05/31 Python
对python 判断数字是否小于0的方法详解
2019/01/26 Python
使用python+whoosh实现全文检索
2019/12/09 Python
HTML5移动端手机网站开发流程
2016/04/25 HTML / CSS
美国著名的婴儿学步鞋老品牌:Robeez
2016/08/20 全球购物
夏洛特和乔治婴儿和儿童时装精品店:Charlotte and George
2018/06/06 全球购物
名词解释型面试题(主要是网络)
2013/12/27 面试题
鼋头渚导游词
2015/02/05 职场文书
村官个人总结范文
2015/03/03 职场文书
java设计模式--七大原则详解
2021/07/21 Java/Android