pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
基于Python的关键字监控及告警
Jul 06 Python
[原创]windows下Anaconda的安装与配置正解(Anaconda入门教程)
Apr 05 Python
Python操作配置文件ini的三种方法讲解
Feb 22 Python
从0开始的Python学习014面向对象编程(推荐)
Apr 02 Python
python识别图像并提取文字的实现方法
Jun 28 Python
Python3 itchat实现微信定时发送群消息的实例代码
Jul 12 Python
在python中将list分段并保存为array类型的方法
Jul 15 Python
Tensorflow分批量读取数据教程
Feb 07 Python
推荐技术人员一款Python开源库(造数据神器)
Jul 08 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
Jan 27 Python
Python绘制K线图之可视化神器pyecharts的使用
Mar 02 Python
pytorch--之halfTensor的使用详解
May 24 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
用php随机生成福彩双色球号码的2种方法
2013/02/04 PHP
php调用KyotoTycoon简单实例
2015/04/02 PHP
在html文件中也可以执行php语句的方法
2015/04/09 PHP
php数字每三位加逗号的功能函数
2015/10/22 PHP
Thinkphp框架 表单自动验证登录注册 ajax自动验证登录注册
2016/12/27 PHP
Yii2实现中国省市区三级联动实例
2017/02/08 PHP
OAuth认证协议中的HMACSHA1加密算法(实例)
2017/10/25 PHP
JavaScript 组件之旅(三):用 Ant 构建组件
2009/10/28 Javascript
JavaScript中的数组特性介绍
2014/12/30 Javascript
ECMAScript6中Set/WeakSet详解
2015/06/12 Javascript
解决bootstrap导航栏navbar在IE8上存在缺陷的方法
2016/07/01 Javascript
Ionic+AngularJS实现登录和注册带验证功能
2017/02/09 Javascript
Vue动态实现评分效果
2017/05/24 Javascript
JavaScript实现开关等效果
2017/09/08 Javascript
Vue组件中slot的用法
2018/01/30 Javascript
AngularJS监听ng-repeat渲染完成的方法
2018/03/20 Javascript
Vue组件中prop属性使用说明实例代码详解
2018/05/31 Javascript
vue实现购物车小案例
2019/09/27 Javascript
[54:18]DOTA2-DPC中国联赛 正赛 PSG.LGD vs LBZS BO3 第一场 1月22日
2021/03/11 DOTA
Python面向对象编程中的类和对象学习教程
2015/03/30 Python
Python的IDEL增加清屏功能实例
2017/06/19 Python
Python中使用Counter进行字典创建以及key数量统计的方法
2018/07/06 Python
解决Keras 与 Tensorflow 版本之间的兼容性问题
2020/02/07 Python
利用Python中的Xpath实现一个在线汇率转换器
2020/09/09 Python
is_file和file_exists效率比较
2021/03/14 PHP
米兰网婚纱礼服法国网上商店:Milanoo法国
2016/08/20 全球购物
美国护肤咨询及美容产品电商:Askderm
2017/02/24 全球购物
如果NULL和0作为空指针常数是等价的,那我到底该用哪一个
2014/09/16 面试题
小学门卫岗位职责
2013/12/17 职场文书
村干部四风问题整改措施
2014/09/30 职场文书
四风问题自查自纠工作情况报告
2014/10/28 职场文书
2015年科协工作总结
2015/05/19 职场文书
Python基础之pandas数据合并
2021/04/27 Python
mysql备份策略的实现(全量备份+增量备份)
2021/07/07 MySQL
python可视化之颜色映射详解
2021/09/15 Python
keepalived + nginx 实现高可用方案
2022/12/24 Servers