画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的zipfile模块使用详解
Jun 25 Python
Python脚本获取操作系统版本信息
Dec 17 Python
Python线性方程组求解运算示例
Jan 17 Python
单利模式及python实现方式详解
Mar 20 Python
Python实现聊天机器人的示例代码
Jul 09 Python
django框架使用orm实现批量更新数据的方法
Jun 21 Python
python求加权平均值的实例(附纯python写法)
Aug 22 Python
python统计字符串中字母出现次数代码实例
Mar 02 Python
python实现简易版学生成绩管理系统
Jun 22 Python
实例代码讲解Python 线程池
Aug 24 Python
python实现视频压缩功能
Dec 18 Python
python爬虫利用代理池更换IP的方法步骤
Feb 21 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
Extjs入门之动态加载树代码
2010/04/09 Javascript
学习javascript,实现插入排序实现代码
2011/07/31 Javascript
silverlight线程与基于事件驱动javascript引擎(实现轨迹回放功能)
2011/08/09 Javascript
javascript实现checkbox全选的代码
2015/04/30 Javascript
JavaScript添加随滚动条滚动窗体的方法
2016/02/23 Javascript
jQuery web 组件 后台日历价格、库存设置的代码
2016/10/14 Javascript
jQuery插件HighCharts实现的2D堆条状图效果示例【附demo源码下载】
2017/03/14 Javascript
Vue.js实现模拟微信朋友圈开发demo
2017/04/20 Javascript
jQuery选择器特殊字符与属性空格问题
2017/08/14 jQuery
vue中通过使用$attrs实现组件之间的数据传递功能
2019/09/01 Javascript
JS立即执行的匿名函数用法分析
2019/11/04 Javascript
浅谈Webpack4 Tree Shaking 终极优化指南
2019/11/18 Javascript
vue+element实现动态加载表单
2020/12/13 Vue.js
python提取页面内url列表的方法
2015/05/25 Python
python中plot实现即时数据动态显示方法
2018/06/22 Python
Python3爬虫学习之MySQL数据库存储爬取的信息详解
2018/12/12 Python
python中pytest收集用例规则与运行指定用例详解
2019/06/27 Python
Python中拆分字符串的操作方法
2019/07/23 Python
解决Python logging模块无法正常输出日志的问题
2020/02/21 Python
Sisley法国希思黎中国官网:享誉全球的奢华植物美容品牌
2019/06/30 全球购物
linux面试题参考答案(8)
2015/08/11 面试题
实习生自荐信范文分享
2013/11/27 职场文书
自动化专业个人求职信范文
2013/11/29 职场文书
学生出入校管理制度
2014/01/16 职场文书
办公室秘书岗位职责范本
2014/02/11 职场文书
成人继续教育实施方案
2014/03/01 职场文书
基层党组织公开承诺书
2014/03/28 职场文书
12岁生日演讲稿
2014/05/14 职场文书
节能环保口号
2014/06/12 职场文书
中学清明节活动总结
2014/07/04 职场文书
通知书大全
2015/04/27 职场文书
农业项目投资意向书
2015/05/09 职场文书
婚育证明样本
2015/06/16 职场文书
2019各种承诺书范文
2019/06/24 职场文书
JVM钩子函数的使用场景详解
2021/08/23 Java/Android
JavaScript最完整的深浅拷贝实现方式详解
2022/02/28 Javascript