pytorch 权重weight 与 梯度grad 可视化操作


Posted in Python onJune 05, 2021

pytorch 权重weight 与 梯度grad 可视化

查看特定layer的权重以及相应的梯度信息

打印模型

pytorch 权重weight 与 梯度grad 可视化操作

观察到model下面有module的key,module下面有features的key, features下面有(0)的key,这样就可以直接打印出weight了

在pdb debug界面输入p model.module.features[0].weight,就可以看到weight,输入 p model.module.features[0].weight.grad 就可以查看梯度信息。

中间变量的梯度 : .register_hook

pytorch 为了节省显存,在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说,有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误,这个过程就需要用到 tensor的register_hook接口

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
x = torch.randn(1, requires_grad=True)
y = 3*x
z = y * y
# 为中间变量注册梯度保存接口,存储梯度时名字为 y。
y.register_hook(save_grad('y'))
# 反向传播 
z.backward()
# 查看 y 的梯度值
print(grads['y'])

打印网络回传梯度

net.named_parameters()

parms.requires_grad 表示该参数是否可学习,是不是frozen的;

parm.grad 打印该参数的梯度值。

net = your_network().cuda()
def train():
 ...
 outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
 for name, parms in net.named_parameters(): 
  print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
   ' -->grad_value:',parms.grad)

查看pytorch产生的梯度

[x.grad for x in self.optimizer.param_groups[0]['params']]

pytorch模型可视化及参数计算

我们在设计完程序以后希望能对我们的模型进行可视化,pytorch这里似乎没有提供相应的包直接进行调用,参考一些博客。

下面把代码贴出来:

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
def make_dot(var, params=None):
   
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}
 
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
 
    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'
 
    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

我们在我们的模型下面直接进行调用就可以了,例如:

if __name__ == "__main__":
    model = DeepLab(backbone='resnet', output_stride=16)
    input = torch.rand(1, 3, 53, 53)
    output = model(input)
    g = make_dot(output)
    g.view()
    params = list(net.parameters())
    k = 0
    for i in params:
        l = 1
        print("该层的结构:" + str(list(i.size())))
        for j in i.size():
            l *= j
        print("该层参数和:" + str(l))
        k = k + l
    print("总参数数量和:" + str(k))

模型部分可视化结果:

pytorch 权重weight 与 梯度grad 可视化操作

参数计算:

pytorch 权重weight 与 梯度grad 可视化操作

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python命令行参数解析模块optparse使用实例
Apr 13 Python
Python使用文件锁实现进程间同步功能【基于fcntl模块】
Oct 16 Python
Python用户推荐系统曼哈顿算法实现完整代码
Dec 01 Python
python将视频转换为全字符视频
Apr 26 Python
Python常用模块logging——日志输出功能(示例代码)
Nov 20 Python
python实现随机加减法生成器
Feb 24 Python
python软件都是免费的吗
Jun 18 Python
python读取excel进行遍历/xlrd模块操作
Jul 12 Python
Python 实现微信自动回复的方法
Sep 11 Python
python 实现定时任务的四种方式
Apr 01 Python
90行Python代码开发个人云盘应用
Apr 20 Python
Flask response响应的具体使用
Jul 15 Python
PyTorch 如何检查模型梯度是否可导
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
You might like
php 文件状态缓存带来的问题
2008/12/14 PHP
PHP 裁剪图片成固定大小代码方法
2009/09/09 PHP
PHP编码转换函数 自动转换字符集支持数组转换
2012/12/16 PHP
php设置session值和cookies的学习示例
2014/03/21 PHP
php正则表达式验证(邮件地址、Url地址、电话号码、邮政编码)
2016/03/14 PHP
php常用图片处理类
2016/03/16 PHP
javascript显示选择目录对话框的代码
2008/11/10 Javascript
jquery 弹出层注册页面等(asp.net后台)
2010/06/17 Javascript
用js实现小球的自由移动代码
2013/04/22 Javascript
JQuery中对Select的option项的添加、删除、取值
2013/08/25 Javascript
JavaScript调用ajax获取文本文件内容实现代码
2014/03/28 Javascript
node.js实现BigPipe详解
2014/12/05 Javascript
jquery拖动层效果插件用法实例分析(附demo源码)
2016/04/28 Javascript
js中作用域的实例解析
2017/03/16 Javascript
探索webpack模块及webpack3新特性
2017/09/18 Javascript
Vue 监听列表item渲染事件方法
2018/09/06 Javascript
微信小程序云开发之数据库操作
2019/05/18 Javascript
[56:45]DOTA2上海特级锦标赛D组小组赛#1 EG VS COL第一局
2016/02/28 DOTA
[50:17]Newbee vs Serenity 2018国际邀请赛小组赛BO2 第二场 8.17
2018/08/18 DOTA
[52:02]DOTA2-DPC中国联赛 正赛 Phoenix vs Dragon BO3 第二场 2月26日
2021/03/11 DOTA
在Python的Django框架中实现Hacker News的一些功能
2015/04/17 Python
python下实现二叉堆以及堆排序的示例
2017/09/29 Python
Python实现多属性排序的方法
2018/12/05 Python
Python3字符串encode与decode的讲解
2019/04/02 Python
python导入坐标点的具体操作
2019/05/10 Python
python多线程分块读取文件
2019/08/29 Python
python实现滑雪者小游戏
2020/02/22 Python
小 200 行 Python 代码制作一个换脸程序
2020/05/12 Python
keras model.fit 解决validation_spilt=num 的问题
2020/06/19 Python
Rowdy Gentleman服装和配饰:美好时光
2019/09/24 全球购物
学前教育毕业生自荐信
2013/10/29 职场文书
汽车专业学生自我评价
2014/01/19 职场文书
优秀毕业生求职信
2014/06/05 职场文书
党员三严三实对照检查材料
2014/10/13 职场文书
2016初一新生军训心得体会
2016/01/11 职场文书
开网店计划分析
2019/07/30 职场文书