pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Flask框架的学习指南之开发环境搭建
Nov 20 Python
django开发教程之利用缓存文件进行页面缓存的方法
Nov 10 Python
Tornado高并发处理方法实例代码
Jan 15 Python
Python函数基础实例详解【函数嵌套,命名空间,函数对象,闭包函数等】
Mar 30 Python
Flask-WTF表单的使用方法
Jul 12 Python
django多对多表的创建,级联删除及手动创建第三张表
Jul 25 Python
PIL对上传到Django的图片进行处理并保存的实例
Aug 07 Python
Python猴子补丁知识点总结
Jan 05 Python
Python实现子类调用父类的初始化实例
Mar 12 Python
Tensorflow中的降维函数tf.reduce_*使用总结
Apr 20 Python
jupyter notebook 添加kernel permission denied的操作
Apr 21 Python
python pygame 愤怒的小鸟游戏示例代码
Feb 25 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
php实现执行某一操作时弹出确认、取消对话框
2013/12/30 PHP
php将一维数组转换为每3个连续值组成的二维数组
2016/05/06 PHP
PHP中关于php.ini参数优化详解
2020/02/28 PHP
求解开jscript.encode代码的asp函数
2007/02/28 Javascript
FF IE兼容性的修改小结
2009/09/02 Javascript
jquery中使用$(#form).submit()重写提交表单无效原因分析及解决
2013/03/25 Javascript
ajax与302响应代码测试
2013/10/23 Javascript
jQuery获得页面元素的绝对/相对位置即绝对X,Y坐标
2014/03/06 Javascript
JavaScript页面模板库handlebars的简单用法
2015/03/02 Javascript
针对BootStrap中tabs控件的美化和完善(推荐)
2016/07/06 Javascript
深入剖析JavaScript面向对象编程
2016/07/12 Javascript
node.js报错:Cannot find module 'ejs'的解决办法
2016/12/14 Javascript
jQuery 全选 全不选 事件绑定的实现代码
2017/01/23 Javascript
微信小程序实现图片轮播及文件上传
2017/04/07 Javascript
layui实现二维码弹窗、并下载到本地的方法
2019/09/25 Javascript
JavaScript的一些小技巧分享
2021/01/06 Javascript
跟老齐学Python之有容乃大的list(4)
2014/09/28 Python
Python中用于返回绝对值的abs()方法
2015/05/14 Python
十分钟利用Python制作属于你自己的个性logo
2018/05/07 Python
Python实现连接MySql数据库及增删改查操作详解
2019/04/16 Python
python实现视频分帧效果
2019/05/31 Python
python实现单目标、多目标、多尺度、自定义特征的KCF跟踪算法(实例代码)
2020/01/08 Python
Python+OpenCV实现旋转文本校正方式
2020/01/09 Python
pycharm中选中一个单词替换所有重复单词的实现方法
2020/11/17 Python
美国最大的香水连锁店官网:Perfumania
2016/08/15 全球购物
Holland & Barrett爱尔兰:英国领先的健康零售商
2019/03/31 全球购物
小车司机岗位职责
2013/11/25 职场文书
房产转让协议书
2014/04/11 职场文书
诚信考试倡议书
2014/04/15 职场文书
应聘护士求职信
2014/07/21 职场文书
储备店长岗位职责
2015/04/14 职场文书
MySQL Router的安装部署
2021/04/24 MySQL
JVM上高性能数据格式库包Apache Arrow入门和架构详解(Gkatziouras)
2021/05/26 Servers
JavaScript中MutationObServer监听DOM元素详情
2021/11/27 Javascript
Redis命令处理过程源码解析
2022/02/12 Redis
如何通过一篇文章了解Python中的生成器
2022/04/02 Python