画pytorch模型图,以及参数计算的方法


Posted in Python onAugust 17, 2019

刚入pytorch的坑,代码还没看太懂。之前用keras用习惯了,第一次使用pytorch还有些不适应,希望广大老司机多多指教。

首先说说,我们如何可视化模型。在keras中就一句话,keras.summary(),或者plot_model(),就可以把模型展现的淋漓尽致。

但是pytorch中好像没有这样一个api让我们直观的看到模型的样子。但是有网友提供了一段代码,可以把模型画出来,对我来说简直就是如有神助啊。

话不多说,上代码吧。

import torch
from torch.autograd import Variable
import torch.nn as nn
from graphviz import Digraph
 
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=2)
    )
    self.out = nn.Linear(32*7*7, 10)
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1) # (batch, 32*7*7)
    out = self.out(x)
    return out
 
 
def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
 
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot
 
 
if __name__ == '__main__':
  net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()
 
  params = list(net.parameters())
  k = 0
  for i in params:
    l = 1
    print("该层的结构:" + str(list(i.size())))
    for j in i.size():
      l *= j
    print("该层参数和:" + str(l))
    k = k + l
  print("总参数数量和:" + str(k))

模型很简单,代码也很简单。就是conv -> relu -> maxpool -> conv -> relu -> maxpool -> fc

大家在可视化的时候,直接复制make_dot那段代码即可,然后需要初始化一个net,以及这个网络需要的数据规模,此处就以 这段代码为例,初始化一个模型net,准备这个模型的输入数据x,shape为(batch,channels,height,width) 然后把数据传入模型得到输出结果y。传入make_dot即可得到下图。

net = CNN()
  x = Variable(torch.randn(1, 1, 28, 28))
  y = net(x)
  g = make_dot(y)
  g.view()

画pytorch模型图,以及参数计算的方法

最后输出该网络的各种参数。

该层的结构:[16, 1, 5, 5]
该层参数和:400
该层的结构:[16]
该层参数和:16
该层的结构:[32, 16, 5, 5]
该层参数和:12800
该层的结构:[32]
该层参数和:32
该层的结构:[10, 1568]
该层参数和:15680
该层的结构:[10]
该层参数和:10
总参数数量和:28938

以上这篇画pytorch模型图,以及参数计算的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详细介绍Python中的偏函数
Apr 27 Python
Python解析excel文件存入sqlite数据库的方法
Nov 15 Python
Python实现简单的HttpServer服务器示例
Sep 25 Python
Python+Django搭建自己的blog网站
Mar 13 Python
python pandas.DataFrame选取、修改数据最好用.loc,.iloc,.ix实现
Jun 11 Python
对pandas读取中文unicode的csv和添加行标题的方法详解
Dec 12 Python
python远程连接MySQL数据库
Apr 19 Python
浅谈keras的深度模型训练过程及结果记录方式
Jan 24 Python
Python如何实现FTP功能
May 28 Python
Pytorch中TensorBoard及torchsummary的使用详解
May 12 Python
Python socket如何解析HTTP请求内容
Feb 12 Python
python实现双向链表原理
May 25 Python
pytorch 共享参数的示例
Aug 17 #Python
Pytorch卷积层手动初始化权值的实例
Aug 17 #Python
pytorch自定义初始化权重的方法
Aug 17 #Python
在Pytorch中使用样本权重(sample_weight)的正确方法
Aug 17 #Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
You might like
php 删除记录实现代码
2009/03/12 PHP
PHP实现采集抓取淘宝网单个商品信息
2015/01/08 PHP
实例讲解YII2中多表关联的使用方法
2017/07/21 PHP
零基础php编程好学吗
2019/10/11 PHP
如何用javascript控制上传文件的大小
2006/10/26 Javascript
javascript 有用的脚本函数
2009/05/07 Javascript
小议Javascript中的this指针
2010/03/18 Javascript
JavaScript 对象的属性和方法4种不同的类型
2010/03/19 Javascript
用JS控制回车事件的代码
2011/02/20 Javascript
javascript学习笔记(六) Date 日期类型
2012/06/19 Javascript
javaScript中Math()函数注意事项
2015/06/18 Javascript
轻松掌握jQuery中wrap()与unwrap()函数的用法
2016/05/24 Javascript
性能优化之代码优化页面加载速度
2017/03/01 Javascript
微信小程序 页面传值详解
2017/03/10 Javascript
js实现日期显示的一些操作(实例讲解)
2017/07/27 Javascript
jQuery easyui datagird编辑行删除行功能的实现代码
2018/09/20 jQuery
vue 解决addRoutes多次添加路由重复的操作
2020/08/04 Javascript
python备份文件以及mysql数据库的脚本代码
2013/06/10 Python
利用Python写一个爬妹子的爬虫
2018/06/08 Python
python pygame模块编写飞机大战
2018/11/20 Python
python 实现多维数组转向量
2019/11/30 Python
pytorch使用 to 进行类型转换方式
2020/01/08 Python
python实现密度聚类(模板代码+sklearn代码)
2020/04/27 Python
基于PyQT实现区分左键双击和单击
2020/05/19 Python
CSS3实现线性渐变用法示例代码详解
2020/08/07 HTML / CSS
ECCO爱步加拿大官网:北欧丹麦鞋履及皮具品牌
2017/07/08 全球购物
Proenza Schouler官方网站:纽约女装和配饰品牌
2019/01/03 全球购物
Lulu Guinness露露·吉尼斯官网:红唇包
2019/02/03 全球购物
综合实践活动方案
2014/02/14 职场文书
工作迟到检讨书
2014/02/21 职场文书
阳光体育活动实施方案
2014/05/25 职场文书
go语言基础 seek光标位置os包的使用
2021/05/09 Golang
MySQL 常见存储引擎的优劣
2021/06/02 MySQL
Apache Hudi数据布局黑科技降低一半查询时间
2022/03/31 Servers
Python如何快速找到多个字典中的公共键(key)
2022/04/29 Python
VMware虚拟机安装 Windows Server 2022的详细图文教程
2022/09/23 Servers