画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中关于日期时间处理的问答集锦
Mar 08 Python
python插入排序算法的实现代码
Nov 21 Python
在Python编程过程中用单元测试法调试代码的介绍
Apr 02 Python
微信跳一跳游戏python脚本
Apr 01 Python
Python工程师面试必备25条知识点
Jan 17 Python
通过python爬虫赚钱的方法
Jan 29 Python
Python中list循环遍历删除数据的正确方法
Sep 02 Python
python GUI库图形界面开发之PyQt5信号与槽多窗口数据传递详细使用方法与实例
Mar 08 Python
python opencv角点检测连线功能的实现代码
Nov 24 Python
python调用百度AI接口实现人流量统计
Feb 03 Python
利用Python判断你的密码难度等级
Jun 02 Python
python开发飞机大战游戏
Jul 15 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
php的list()的一步操作给一组变量进行赋值的使用
2011/05/18 PHP
网站防止被刷票的一些思路与方法
2015/01/08 PHP
PHP数组操作实例分析【添加,删除,计算,反转,排序,查找等】
2016/12/24 PHP
php实现登录页面的简单实例
2019/09/29 PHP
php实现JWT(json web token)鉴权实例详解
2019/11/05 PHP
ie和firefox中img对象区别的困惑
2006/12/27 Javascript
Mootools 1.2教程 滑动效果(Slide)
2009/09/15 Javascript
jquery序列化表单去除指定元素示例代码
2014/04/10 Javascript
js时钟翻牌效果实现代码分享
2020/07/31 Javascript
JavaScript位移运算符(无符号) >>> 三个大于号 的使用方法详解
2016/03/31 Javascript
json传值以及ajax接收详解
2016/05/24 Javascript
jQuery实现的多张图无缝滚动效果【测试可用】
2016/09/12 Javascript
浅谈移动端之js touch事件 手势滑动事件
2016/11/07 Javascript
nodejs调取微信收货地址的方法
2017/12/20 NodeJs
jquery实现商品sku多属性选择功能(商品详情页)
2019/12/20 jQuery
JS实现鼠标移动拖尾
2020/12/27 Javascript
全面解读Python Web开发框架Django
2014/06/30 Python
python循环监控远程端口的方法
2015/03/14 Python
基于Linux系统中python matplotlib画图的中文显示问题的解决方法
2017/06/15 Python
浅谈Python中的zip()与*zip()函数详解
2018/02/24 Python
Python matplotlib的使用并自定义colormap的方法
2018/12/13 Python
python FTP批量下载/删除/上传实例
2019/12/22 Python
TensorFlow打印输出tensor的值
2020/04/19 Python
一文轻松掌握python语言命名规范规则
2020/06/18 Python
Opencv 图片的OCR识别的实战示例
2021/03/02 Python
css 元素选择器的简单实例
2016/05/23 HTML / CSS
商场促销活动方案
2014/02/08 职场文书
主办会计岗位职责
2014/03/13 职场文书
酒店开业庆典主持词
2014/03/21 职场文书
幼儿园中班上学期评语
2014/04/18 职场文书
教师敬业奉献模范事迹材料
2014/05/18 职场文书
化学专业毕业生求职信
2014/07/28 职场文书
停车场管理制度范本
2015/08/05 职场文书
给领导敬酒词
2015/08/12 职场文书
python自动化之如何利用allure生成测试报告
2021/05/02 Python
Mysql数据库事务的脏读幻读及不可重复读详解
2022/05/30 MySQL