画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络编程学习笔记(四):域名系统
Jun 09 Python
利用Python破解验证码实例详解
Dec 08 Python
numpy实现合并多维矩阵、list的扩展方法
May 08 Python
Python之批量创建文件的实例讲解
May 10 Python
python pandas实现excel转为html格式的方法
Oct 23 Python
Python编程实现tail-n查看日志文件的方法
Jul 08 Python
python或C++读取指定文件夹下的所有图片
Aug 31 Python
python matplotlib包图像配色方案分享
Mar 14 Python
浅谈Python中os模块及shutil模块的常规操作
Apr 03 Python
keras model.fit 解决validation_spilt=num 的问题
Jun 19 Python
Python Django路径配置实现过程解析
Nov 05 Python
python迷宫问题深度优先遍历实例
Jun 20 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
php制作unicode解码工具(unicode编码转换器)代码分享
2013/12/24 PHP
php教程之phpize使用方法
2014/02/12 PHP
php中socket通信机制实例详解
2015/01/03 PHP
PHP与SQL语句常用大全
2016/12/10 PHP
js活用事件触发对象动作
2008/08/10 Javascript
Div自动滚动到末尾的代码
2008/10/26 Javascript
jQuery ajax cache缓存问题
2010/07/01 Javascript
基于jquery的实现简单的表格中增加或删除下一行
2010/08/01 Javascript
JavaScript聚焦于第一个字段的代码
2010/10/15 Javascript
给页面渲染时间加速 干掉Dom Level 0 Event
2012/12/19 Javascript
简单的两种Extjs formpanel加载数据的方式
2013/11/09 Javascript
Node.js的包详细介绍
2015/01/14 Javascript
JS使用cookie实现DIV提示框只显示一次的方法
2015/11/05 Javascript
利用node.js爬取指定排名网站的JS引用库详解
2017/07/25 Javascript
详解webpack + react + react-router 如何实现懒加载
2017/11/20 Javascript
Vue实现导航栏点击当前标签变色功能
2020/08/19 Javascript
node+multer实现图片上传的示例代码
2020/02/18 Javascript
Python字符转换
2008/09/06 Python
对python 匹配字符串开头和结尾的方法详解
2018/10/27 Python
django 多对多表的创建和插入代码实现
2019/09/09 Python
python实现输入任意一个大写字母生成金字塔的示例
2019/10/27 Python
python对Excel按条件进行内容补充(推荐)
2019/11/24 Python
Python pip配置国内源的方法
2020/02/14 Python
jupyter修改文件名方式(TensorFlow)
2020/04/21 Python
一篇文章带你搞定Ubuntu中打开Pycharm总是卡顿崩溃
2020/11/02 Python
Python web框架(django,flask)实现mysql数据库读写分离的示例
2020/11/18 Python
英国演唱会订票网站:Ticket Selection
2018/03/27 全球购物
老板电器官方购物商城:老板油烟机、燃气灶、消毒柜、电烤箱
2018/05/30 全球购物
行政副总岗位职责
2014/02/23 职场文书
态度决定一切演讲稿
2014/05/20 职场文书
大专生找工作自荐书
2014/06/10 职场文书
个人总结格式范文
2015/03/09 职场文书
2015年防汛工作总结
2015/05/15 职场文书
2015年防灾减灾工作总结
2015/07/24 职场文书
python opencv旋转图片的使用方法
2021/06/04 Python
ubuntu安装jupyter并设置远程访问的实现
2022/03/31 Python