画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python上传package到Pypi(代码简单)
Feb 06 Python
Python引用计数操作示例
Aug 23 Python
浅谈python在提示符下使用open打开文件失败的原因及解决方法
Nov 30 Python
postman模拟访问具有Session的post请求方法
Jul 15 Python
pandas factorize实现将字符串特征转化为数字特征
Dec 19 Python
Python类继承和多态原理解析
Feb 05 Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 Python
Python Excel vlookup函数实现过程解析
Jun 22 Python
Python如何对XML 解析
Jun 28 Python
Pytorch 中的optimizer使用说明
Mar 03 Python
Python合并pdf文件的工具
Jul 01 Python
Python学习之迭代器详解
Apr 01 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
php.ini 配置文件的深入解析
2013/06/17 PHP
PHP的Yii框架的基本使用示例
2015/08/21 PHP
Yii框架引用插件和ckeditor中body与P标签去除的方法
2017/01/19 PHP
深入理解Yii2.0乐观锁与悲观锁的原理与使用
2017/07/26 PHP
PHP排序算法之快速排序(Quick Sort)及其优化算法详解
2018/04/21 PHP
PHP7.1实现的AES与RSA加密操作示例
2018/06/15 PHP
CCPry JS类库 代码
2009/10/30 Javascript
Javascript结合css实现网页换肤功能
2009/11/02 Javascript
基于jQuery实现左右div自适应高度完全相同的代码
2012/08/09 Javascript
js调用AJAX时Get和post的乱码解决方法
2013/06/04 Javascript
深入理解JavaScript系列(45):代码复用模式(避免篇)详解
2015/03/04 Javascript
一些实用性较高的js方法
2016/04/19 Javascript
微信小程序实现图片轮播及文件上传
2017/04/07 Javascript
浅谈ES6新增的数组方法和对象
2017/08/08 Javascript
javascript异步处理与Jquery deferred对象用法总结
2019/06/04 jQuery
微信小程序实现手指拖动选项排序
2020/04/22 Javascript
Python的Django REST框架中的序列化及请求和返回
2016/04/11 Python
django 通过ajax完成邮箱用户注册、激活账号的方法
2018/04/17 Python
Python使用pyautogui模块实现自动化鼠标和键盘操作示例
2018/09/04 Python
python实现抽奖小程序
2020/04/15 Python
python制作简单五子棋游戏
2019/06/18 Python
python 随机森林算法及其优化详解
2019/07/11 Python
Python操作多维数组输出和矩阵运算示例
2019/11/28 Python
多个python文件调用logging模块报错误
2020/02/12 Python
Python多线程Threading、子线程与守护线程实例详解
2020/03/24 Python
pyqt5数据库使用详细教程(打包解决方案)
2020/03/25 Python
利用Python如何实时检测自身内存占用
2020/05/09 Python
linux系统下pip升级报错的解决方法
2021/01/31 Python
Giglio德国网上精品店:奢侈品服装和配件
2016/09/23 全球购物
意大利值得信赖的在线超级药房:PillolaStore
2020/02/05 全球购物
销售经理工作职责范文
2013/12/03 职场文书
怎样写好自荐信和推荐信
2013/12/26 职场文书
教师自查自纠工作情况报告
2014/10/29 职场文书
医德医风学习心得体会
2016/01/25 职场文书
python spilt()分隔字符串的实现示例
2021/05/21 Python
SQL注入篇学习之盲注/宽字节注入
2022/03/03 MySQL