pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3实现短网址和数字相互转换的方法
Apr 28 Python
python简单实现刷新智联简历
Mar 30 Python
深入理解python多进程编程
Jun 12 Python
Python采用Django制作简易的知乎日报API
Aug 03 Python
浅析Python中元祖、列表和字典的区别
Aug 17 Python
Django 前后台的数据传递的方法
Aug 08 Python
python之从文件读取数据到list的实例讲解
Apr 19 Python
Python使用Windows API创建窗口示例【基于win32gui模块】
May 09 Python
对numpy Array [: ,] 的取值方法详解
Jul 02 Python
python实现图像拼接
Mar 05 Python
keras.layer.input()用法说明
Jun 16 Python
python如何写try语句
Jul 14 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
php截取html字符串及自动补全html标签的方法
2015/01/15 PHP
找到一点可怜的关于dojo资料,谢谢作者!
2006/12/06 Javascript
javascript appendChild,innerHTML,join性能比较代码
2009/08/29 Javascript
基于jquery的表格排序
2010/09/11 Javascript
js计算精度问题小结
2013/04/22 Javascript
jquery.cookie() 方法的使用(读取、写入、删除)
2013/12/05 Javascript
jQuery 和 CSS 的文本特效插件集锦
2014/12/12 Javascript
js获取内联样式的方法
2015/01/27 Javascript
超级简单实现JavaScript MVC 样式框架
2015/03/24 Javascript
JavaScript文本框脚本编写的注意事项
2016/01/25 Javascript
js获取对象、数组的实际长度,元素实际个数的实现代码
2016/06/08 Javascript
JS组件系列之使用HTML标签的data属性初始化JS组件
2016/09/14 Javascript
js移动焦点到最后位置的简单方法
2016/11/25 Javascript
jQuery基于ajax实现页面加载后检查用户登录状态的方法
2017/02/10 Javascript
原生JavaScript实现Ajax异步请求
2017/11/19 Javascript
JS实现的抛物线运动效果示例
2018/01/30 Javascript
JavaScript图片处理与合成总结
2018/03/04 Javascript
Vue 监听列表item渲染事件方法
2018/09/06 Javascript
JavaScript链式调用实例浅析
2018/12/19 Javascript
JS多个异步请求 按顺序执行next实现解析
2019/09/16 Javascript
js实现计时器秒表功能
2019/12/16 Javascript
JavaScript冒泡算法原理与实现方法深入理解
2020/06/04 Javascript
解决Vue @submit 提交后不刷新页面问题
2020/07/18 Javascript
JavaScript实现网页动态生成表格
2020/11/25 Javascript
[51:28]EG vs Mineski 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/16 DOTA
python多任务及返回值的处理方法
2019/01/22 Python
python绘制漏斗图步骤详解
2019/03/04 Python
python制作简单五子棋游戏
2019/06/18 Python
10个python爬虫入门实例(小结)
2020/11/01 Python
Spartoo瑞典:鞋子、包包和衣服
2018/09/15 全球购物
在印度上传处方,在线订购药品:Medlife
2019/03/28 全球购物
照片礼物和装饰:MyPhoto
2019/11/02 全球购物
《云雀的心愿》教学反思
2014/02/25 职场文书
股权转让协议书
2014/04/12 职场文书
2015教师见习期工作总结
2014/12/12 职场文书
运动会广播稿50字
2015/08/19 职场文书