pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Windows下安装python2.7及科学计算套装
Mar 05 Python
利用Python绘制数据的瀑布图的教程
Apr 07 Python
Python while、for、生成器、列表推导等语句的执行效率测试
Jun 03 Python
详谈Python2.6和Python3.0中对除法操作的异同
Apr 28 Python
Python识别快递条形码及Tesseract-OCR使用详解
Jul 15 Python
tensorflow指定GPU与动态分配GPU memory设置
Feb 03 Python
python3.6.8 + pycharm + PyQt5 环境搭建的图文教程
Jun 11 Python
python如何查看安装了的模块
Jun 23 Python
django 将自带的数据库sqlite3改成mysql实例
Jul 09 Python
Python 代码调试技巧示例代码
Aug 11 Python
python logging模块的使用详解
Oct 23 Python
去除python中的字符串空格的简单方法
Dec 22 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
《魔兽争霸3》重制版究竟重制了什么?玩家:这么糊弄真的好吗?
2020/05/04 魔兽争霸
php数据库抽象层 PDO
2011/05/07 PHP
基于PHP CURL获取邮箱地址的详解
2013/06/03 PHP
PHP中使用glob函数实现一句话删除某个目录下的所有文件
2014/07/22 PHP
PHP实现GIF图片验证码
2015/11/04 PHP
PHP实现UTF8二进制及明文字符串的转化功能示例
2017/11/20 PHP
Laravel框架处理用户的请求操作详解
2019/12/20 PHP
贴一个在Mozilla中常用的Javascript代码
2007/01/09 Javascript
dtree 网页树状菜单及传递对象集合到js内,动态生成节点
2012/04/14 Javascript
jQuery操作checkbox选择(list/table)
2013/04/07 Javascript
关于extjs4如何获取grid修改后的数据的问题
2013/08/07 Javascript
javascipt匹配单行和多行注释的正则表达式
2013/11/20 Javascript
jQuery zTree加载树形菜单功能
2016/02/25 Javascript
从重置input file标签中看jQuery的 .val() 和 .attr(“value”) 区别
2016/06/12 Javascript
AngularJS教程之MVC体系结构详解
2016/08/16 Javascript
详谈js中window.location.search的用法和作用
2017/02/13 Javascript
判断横屏竖屏(三种)
2017/02/13 Javascript
jQuery插件FusionCharts实现的3D柱状图效果实例【附demo源码下载】
2017/03/03 Javascript
babel的使用及安装配置教程
2018/02/22 Javascript
ES7之Async/await的使用详解
2019/03/28 Javascript
使vue实现jQuery调用的两种方法
2019/05/12 jQuery
seajs和requirejs模块化简单案例分析
2019/08/26 Javascript
[02:40]2014DOTA2 国际邀请赛中国区预选赛 四大豪门抵达华西村
2014/05/23 DOTA
[35:39]完美世界DOTA2联赛PWL S2 FTD.C vs Rebirth 第二场 11.22
2020/11/24 DOTA
python实现根据主机名字获得所有ip地址的方法
2015/06/28 Python
Mac下Anaconda的安装和使用教程
2018/11/29 Python
Python微医挂号网医生数据抓取
2019/01/24 Python
django商品分类及商品数据建模实例详解
2020/01/03 Python
Python 简单计算要求形状面积的实例
2020/01/18 Python
html5 分层屏幕适配的方法
2018/03/16 HTML / CSS
HTML5是否真的可以取代Flash
2010/02/10 HTML / CSS
Mio Skincare美国官网:身体紧致及孕期身体护理
2017/03/05 全球购物
大学毕业的自我鉴定
2013/10/08 职场文书
新闻记者个人求职的自我评价
2013/11/28 职场文书
加拿大探亲邀请信
2014/01/28 职场文书
安全生产先进个人总结
2015/02/15 职场文书