画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python图像处理之镜像实现方法
May 30 Python
全面解析Python的While循环语句的使用方法
Oct 13 Python
PYTHON 中使用 GLOBAL引发的一系列问题
Oct 12 Python
python中import学习备忘笔记
Jan 24 Python
python判断字符串是否是json格式方法分享
Nov 07 Python
Python对象中__del__方法起作用的条件详解
Nov 01 Python
VSCode Python开发环境配置的详细步骤
Feb 22 Python
在django中图片上传的格式校验及大小方法
Jul 28 Python
python3中numpy函数tile的用法详解
Dec 04 Python
python飞机大战pygame游戏框架搭建操作详解
Dec 17 Python
python实现双色球随机选号
Jan 01 Python
Django3中的自定义用户模型实例详解
Aug 23 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
php流量统计功能的实现代码
2012/09/29 PHP
php function用法如何递归及return和echo区别
2014/03/07 PHP
ThinkPHP做文字水印时提示call an undefined function exif_imagetype()解决方法
2014/10/30 PHP
php下Memcached入门实例解析
2015/01/05 PHP
php实现指定字符串中查找子字符串的方法
2015/03/17 PHP
Laravel 5.5 的自定义验证对象/类示例代码详解
2017/08/29 PHP
HTML中不支持静态Expando的元素的问题
2007/03/08 Javascript
Javascript 更新 JavaScript 数组的 uniq 方法
2008/01/23 Javascript
非主流的textarea自增长实现js代码
2011/12/20 Javascript
从数据结构分析看:用for each...in 比 for...in 要快些
2013/04/17 Javascript
JS localStorage实现本地缓存的方法
2013/06/22 Javascript
在js文件中如何获取basePath处理js路径问题
2013/07/10 Javascript
js实现收缩菜单效果实例代码
2013/10/30 Javascript
jQuery中的$.ajax()方法应用
2014/05/06 Javascript
从零学JSON之JSON数据结构
2014/05/19 Javascript
Javascript学习笔记之 对象篇(一) : 对象的使用和属性
2014/06/24 Javascript
jQuery on方法传递参数示例
2014/12/09 Javascript
javascript中slice(),splice(),split(),substring(),substr()使用方法
2015/03/13 Javascript
Node.js事件驱动
2015/06/18 Javascript
jQuery实现textarea自动增长宽高的方法
2015/12/18 Javascript
详解微信小程序开发之下拉刷新 上拉加载
2016/11/24 Javascript
node.js操作mongodb简单示例分享
2017/05/25 Javascript
angular.js和vue.js中实现函数去抖示例(debounce)
2018/01/18 Javascript
基于ionic实现下拉刷新功能
2018/05/10 Javascript
详解vue axios用post提交的数据格式
2018/08/07 Javascript
vue.js使用v-model实现表单元素(input) 双向数据绑定功能示例
2019/03/08 Javascript
JS检索下拉列表框中被选项目的索引号(selectedIndex)
2019/12/17 Javascript
Python里disconnect UDP套接字的方法
2015/04/23 Python
python中常见错误及解决方法
2020/06/21 Python
Python enumerate() 函数如何实现索引功能
2020/06/29 Python
python调用有道智云API实现文件批量翻译
2020/10/10 Python
用canvas做一个DVD待机动画的实现代码
2019/04/12 HTML / CSS
LORAC官网:美国彩妆品牌
2019/08/27 全球购物
客房服务员岗位职责
2015/02/09 职场文书
2016年党员承诺书范文
2016/03/24 职场文书
Python 数据可视化工具 Pyecharts 安装及应用
2022/04/20 Python