画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中使用urllib2伪造HTTP报头的2个方法
Jul 07 Python
python处理二进制数据的方法
Jun 03 Python
Python正则表达式分组概念与用法详解
Jun 24 Python
Python二叉搜索树与双向链表转换算法示例
Mar 02 Python
python循环定时中断执行某一段程序的实例
Jun 29 Python
mac系统下Redis安装和使用步骤详解
Jul 09 Python
基于python实现蓝牙通信代码实例
Nov 19 Python
python多线程使用方法实例详解
Dec 30 Python
python 6.7 编写printTable()函数表格打印(完整代码)
Mar 25 Python
在Matplotlib图中插入LaTex公式实例
Apr 17 Python
Python列表元素删除和remove()方法详解
Jan 04 Python
安装pytorch时报sslerror错误的解决方案
May 17 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
PHP生成静态页面详解
2006/11/19 PHP
php生成随机颜色的方法
2014/11/13 PHP
PHP的PDO连接讲解
2019/01/24 PHP
escape、encodeURI、encodeURIComponent等方法的区别比较
2006/12/27 Javascript
基于Jquery的仿Windows Aero弹出窗(漂亮的关闭按钮)
2010/09/28 Javascript
js封装的textarea操作方法集合(兼容很好)
2010/11/16 Javascript
jsvascript图像处理—(计算机视觉应用)图像金字塔
2013/01/15 Javascript
JavaScript去除空格的三种方法(正则/传参函数/trim)
2013/02/06 Javascript
div模拟滚动条效果示例代码
2013/10/16 Javascript
我的Node.js学习之路(四)--单元测试
2014/07/06 Javascript
JavaScript语言精粹经典实例(整理篇)
2016/06/07 Javascript
Summernote实现图片上传功能的简单方法
2016/07/11 Javascript
快速移动鼠标触发问题及解决方法(ECharts外部调用保存为图片操作及工作流接线mouseenter和mouseleave)
2016/08/29 Javascript
javascript实现延时显示提示框效果
2017/06/01 Javascript
Angular.js中上传指令ng-upload的基本使用教程
2017/07/30 Javascript
Webpack path与publicPath的区别详解
2018/05/03 Javascript
详解vue-cli脚手架中webpack配置方法
2018/08/22 Javascript
Vue 实时监听窗口变化 windowresize的两种方法
2018/11/06 Javascript
layui监听单元格编辑前后交互的例子
2019/09/16 Javascript
Element PageHeader页头的使用方法
2020/07/26 Javascript
Vue使用v-viewer实现图片预览
2020/10/21 Javascript
在Heroku云平台上部署Python的Django框架的教程
2015/04/20 Python
总结Python编程中函数的使用要点
2016/03/20 Python
利用pandas将numpy数组导出生成excel的实例
2018/06/14 Python
在PyCharm下使用 ipython 交互式编程的方法
2019/01/17 Python
python3 Scrapy爬虫框架ip代理配置的方法
2020/01/17 Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
2020/05/26 Python
两则小学生的自我评价分享
2013/11/14 职场文书
八年级音乐教学反思
2014/01/09 职场文书
市级青年文明号申报材料
2014/05/26 职场文书
电气自动化求职信
2014/06/24 职场文书
房屋过户委托书范本
2014/10/07 职场文书
学习雷锋精神活动总结
2015/02/06 职场文书
2016年国陪研修感言
2015/11/18 职场文书
JavaScript实现复选框全选功能
2021/04/11 Javascript
Android Flutter实现图片滑动切换效果
2022/04/07 Java/Android