pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
以Python的Pyspider为例剖析搜索引擎的网络爬虫实现方法
Mar 30 Python
浅析Python编写函数装饰器
Mar 18 Python
详谈套接字中SO_REUSEPORT和SO_REUSEADDR的区别
Apr 28 Python
Jupyter中直接显示Matplotlib的图形方法
May 24 Python
python验证码识别教程之利用滴水算法分割图片
Jun 05 Python
Python引用计数操作示例
Aug 23 Python
Python实现二维曲线拟合的方法
Dec 29 Python
python3实现二叉树的遍历与递归算法解析(小结)
Jul 03 Python
基于python实现计算且附带进度条代码实例
Mar 31 Python
Python绘制全球疫情变化地图的实例代码
Apr 20 Python
win10从零安装配置pytorch全过程图文详解
May 08 Python
Django-Scrapy生成后端json接口的方法示例
Oct 06 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
使用Apache的rewrite技术
2006/06/22 PHP
模仿OSO的论坛(二)
2006/10/09 PHP
解析php5配置使用pdo
2013/07/03 PHP
PHP简单实现数字分页功能示例
2016/08/24 PHP
php 三大特点:封装,继承,多态
2017/02/19 PHP
javascript 静态对象和构造函数的使用和公私问题
2010/03/02 Javascript
getElementByIdx_x js自定义getElementById函数
2012/01/24 Javascript
十个迅速提升JQuery性能让你的JQuery跑得更快
2012/12/10 Javascript
javascript数组快速打乱重排的方法
2014/01/02 Javascript
JS如何将数字类型转化为没3个一个逗号的金钱格式
2014/01/27 Javascript
使用js画图之画切线
2015/01/12 Javascript
微信小程序使用第三方库Underscore.js步骤详解
2016/09/27 Javascript
利用vue + element实现表格分页和前端搜索的方法
2017/12/25 Javascript
React Native使用fetch实现图片上传的示例代码
2018/03/07 Javascript
微信小程序云开发修改云数据库中的数据方法
2019/05/18 Javascript
JavaScript实现轮播图效果代码实例
2019/09/28 Javascript
解决Vue的项目使用Element ui 走马灯无法实现的问题
2020/08/03 Javascript
如何使用 JavaScript 操作浏览器历史记录 API
2020/11/24 Javascript
详解Vue 的异常处理机制
2020/11/30 Vue.js
jQuery实现简单轮播图效果
2020/12/27 jQuery
python获取文件后缀名及批量更新目录下文件后缀名的方法
2014/11/11 Python
Python if语句知识点用法总结
2018/06/10 Python
python 阶乘累加和的实例
2019/02/01 Python
Python实现带下标索引的遍历操作示例
2019/05/30 Python
django queryset 去重 .distinct()说明
2020/05/19 Python
Python爬虫之Selenium中frame/iframe表单嵌套页面
2020/12/04 Python
python中四舍五入的正确打开方式
2021/01/18 Python
4S店售后客服自我评价
2014/04/09 职场文书
社区文艺活动方案
2014/08/19 职场文书
红色旅游心得体会
2014/09/03 职场文书
运动会搞笑广播稿
2014/10/14 职场文书
2014年服务员工作总结
2014/11/18 职场文书
食品质检员岗位职责
2015/04/08 职场文书
2015年保管员工作总结
2015/04/30 职场文书
golang 如何用反射reflect操作结构体
2021/04/28 Golang
Golang 遍历二叉树
2022/04/19 Golang