pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python转换HTML到Text纯文本的方法
Jan 15 Python
简单的python后台管理程序
Apr 13 Python
Python中模块pymysql查询结果后如何获取字段列表
Jun 05 Python
pygame游戏之旅 如何制作游戏障碍
Nov 20 Python
Python创建字典的八种方式
Feb 27 Python
通过python实现随机交换礼物程序详解
Jul 10 Python
完美解决python3.7 pip升级 拒绝访问问题
Jul 12 Python
Python爬虫学习之获取指定网页源码
Jul 30 Python
Python坐标线性插值应用实现
Nov 13 Python
让Django的BooleanField支持字符串形式的输入方式
May 20 Python
使用python脚本自动生成K8S-YAML的方法示例
Jul 12 Python
python识别围棋定位棋盘位置
Jul 26 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
php中用数组的方法设置cookies
2011/04/21 PHP
PHP二维数组矩形转置实例
2016/07/20 PHP
Laravel下生成验证码的类
2017/11/15 PHP
动态加载外部javascript文件的函数代码分享
2011/07/28 Javascript
得到jQuery detach()后节点中的某个值实现代码
2013/02/05 Javascript
jquery异步跨域访问代码
2013/06/28 Javascript
ExtJS4 表格的嵌套 rowExpander应用
2014/05/02 Javascript
ie8模式下click无反应点击option无反应的解决方法
2014/10/11 Javascript
jQuery实现选中弹出窗口选择框内容后赋值给文本框的方法
2015/11/23 Javascript
基于jquery实现页面滚动时顶部导航显示隐藏
2020/04/20 Javascript
原生js实现手风琴功能(支持横纵向调用)
2017/01/13 Javascript
利用jQuery实现简单的拖曳效果实例代码
2017/10/20 jQuery
JavaScript 异步调用
2017/10/25 Javascript
微信小程序如何再次获取用户授权的方法
2019/05/10 Javascript
小程序实现上下移动切换位置
2019/09/23 Javascript
JavaScript/TypeScript 实现并发请求控制的示例代码
2021/01/18 Javascript
python开发之IDEL(Python GUI)的使用方法图文详解
2015/11/12 Python
python嵌套函数使用外部函数变量的方法(Python2和Python3)
2016/01/31 Python
在Python中实现替换字符串中的子串的示例
2018/10/31 Python
如何通过50行Python代码获取公众号全部文章
2019/07/12 Python
如何使用python代码操作git代码
2020/02/29 Python
Python之Django自动实现html代码(下拉框,数据选择)
2020/03/13 Python
详解Python openpyxl库的基本应用
2021/02/26 Python
俄罗斯游戏商店:Buka
2020/03/01 全球购物
Java的类可以定义为Protected或者Private得吗
2015/09/25 面试题
计算机应用专业推荐信
2013/11/13 职场文书
危爆物品安全大检查大整治工作方案
2014/05/03 职场文书
党务公开方案
2014/05/06 职场文书
巾帼志愿者活动方案
2014/08/17 职场文书
党员个人对照检查材料思想汇报
2014/09/16 职场文书
餐饮食品安全责任书
2015/01/29 职场文书
二十年同学聚会感言
2015/07/30 职场文书
HR在给员工开具离职证明时,需要注意哪些问题?
2019/07/03 职场文书
解决jupyter notebook图片显示模糊和保存清晰图片的操作
2021/04/24 Python
MySQL Router的安装部署
2021/04/24 MySQL
django如何自定义manage.py管理命令
2021/04/27 Python