画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之集成开发环境(IDE)
Sep 12 Python
python matplotlib 在指定的两个点之间连线方法
May 25 Python
详解Python计算机视觉 图像扭曲(仿射扭曲)
Mar 27 Python
Django 开发环境配置过程详解
Jul 18 Python
python lambda表达式(匿名函数)写法解析
Sep 16 Python
Python调用Windows API函数编写录音机和音乐播放器功能
Jan 05 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 Python
python数据库操作mysql:pymysql、sqlalchemy常见用法详解
Mar 30 Python
有趣的Python图片制作之如何用QQ好友头像拼接出里昂
Apr 22 Python
浅谈Python中的继承
Jun 19 Python
Python容器类型公共方法总结
Aug 19 Python
python Matplotlib数据可视化(1):简单入门
Sep 30 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
php 服务器调试 Zend Debugger 的安装教程
2009/09/25 PHP
PHP实现的增强性mhash函数
2015/05/27 PHP
Nginx实现反向代理
2017/09/20 Servers
laravel5.1框架model类查询的实现方法
2019/10/08 PHP
Thinkphp5.0框架使用模型Model的获取器、修改器、软删除数据操作示例
2019/10/11 PHP
php使用redis的有序集合zset实现延迟队列应用示例
2020/02/20 PHP
innerHTML,outerHTML,innerTEXT三者之间的区别
2007/01/28 Javascript
javascript 获取图片颜色
2009/04/05 Javascript
utf-8编码引起js输出中文乱码的解决办法
2010/06/23 Javascript
JavaScript 字符串处理函数使用小结
2010/12/02 Javascript
js DOM的学习笔记
2011/12/22 Javascript
js窗口关闭提示信息(兼容IE和firefox)
2015/10/23 Javascript
基于jquery编写分页插件
2016/03/07 Javascript
Javascript之面向对象--封装
2016/12/02 Javascript
jQuery插件JWPlayer视频播放器用法实例分析
2017/01/11 Javascript
jQuery实现注册会员时密码强度提示信息功能示例
2017/09/05 jQuery
layuiAdmin循环遍历展示商品图片列表的方法
2019/09/16 Javascript
[14:25]教你分分钟做大人:主宰(HEROS)
2014/12/08 DOTA
[55:56]NB vs Infamous 2019国际邀请赛淘汰赛 败者组 BO3 第二场 8.22
2019/09/05 DOTA
在Python的Flask框架中使用日期和时间的教程
2015/04/21 Python
Python 基于Twisted框架的文件夹网络传输源码
2016/08/28 Python
Python编程pygal绘图实例之XY线
2017/12/09 Python
python manage.py runserver流程解析
2019/11/08 Python
python获取网络图片方法及整理过程详解
2019/12/20 Python
Python加载数据的5种不同方式(收藏)
2020/11/13 Python
英国一家专门出售品牌鞋子的网站:Allsole
2016/08/07 全球购物
澳大利亚首个在线预订旅游网站:Wotif
2017/07/19 全球购物
中国双语服务优势的在线购票及活动平台:247tickets
2018/10/26 全球购物
大学本科毕业生求职简历的自我评价
2013/10/09 职场文书
高中地理教学反思
2014/01/29 职场文书
小学生植树节活动总结
2014/07/04 职场文书
纪念一二九运动演讲稿
2014/09/16 职场文书
2015年七七事变78周年纪念活动方案
2015/05/06 职场文书
离职告别感言
2015/08/04 职场文书
MySQL系列之四 SQL语法
2021/07/02 MySQL
Vue router配置与使用分析讲解
2022/12/24 Vue.js