pytorch打印网络结构的实例


Posted in Python onAugust 19, 2019

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。

(2)生成网络结构的代码:

def make_dot(var, params=None):
  """ Produces Graphviz representation of PyTorch autograd graph
  Blue nodes are the Variables that require grad, orange are Tensors
  saved for backward in torch.autograd.Function
  Args:
    var: output Variable
    params: dict of (name, Variable) to add names to node that
      require grad (TODO: make optional)
  """
  if params is not None:
    assert isinstance(params.values()[0], Variable)
    param_map = {id(v): k for k, v in params.items()}
 
  node_attr = dict(style='filled',
           shape='box',
           align='left',
           fontsize='12',
           ranksep='0.1',
           height='0.2')
  dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
  seen = set()
 
  def size_to_str(size):
    return '('+(', ').join(['%d' % v for v in size])+')'
  def add_nodes(var):
    if var not in seen:
      if torch.is_tensor(var):
        dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
      elif hasattr(var, 'variable'):
        u = var.variable
        name = param_map[id(u)] if params is not None else ''
        node_name = '%s\n %s' % (name, size_to_str(u.size()))
        dot.node(str(id(var)), node_name, fillcolor='lightblue')
      else:
        dot.node(str(id(var)), str(type(var).__name__))
      seen.add(var)
      if hasattr(var, 'next_functions'):
        for u in var.next_functions:
          if u[0] is not None:
            dot.edge(str(id(u[0])), str(id(var)))
            add_nodes(u[0])
      if hasattr(var, 'saved_tensors'):
        for t in var.saved_tensors:
          dot.edge(str(id(t)), str(id(var)))
          add_nodes(t)
  add_nodes(var.grad_fn)
  return dot

(3)打印网络结构:

import torch 
from torch.autograd import Variable 
import torch.nn as nn 
from graphviz import Digraph
 
class CNN(nn.module):
  def __init__(self):
   ******
   def forward(self,x):
   ******
   return out
 
*****************************
def make_dot(): #复制上面的代码
*****************************
 
if __name__ == '__main__': 
  net = CNN() 
  x = Variable(torch.randn(1, 1, 1024,1024)) 
  y = net(x) 
  g = make_dot(y) 
  g.view() 
 
  params = list(net.parameters()) 
  k = 0 
  for i in params: 
    l = 1 
    print("该层的结构:" + str(list(i.size()))) 
    for j in i.size(): 
      l *= j 
    print("该层参数和:" + str(l)) 
    k = k + l 
  print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构的实例

以上这篇pytorch打印网络结构的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现apahce网站日志分析示例
Apr 02 Python
理解python多线程(python多线程简明教程)
Jun 09 Python
python中pycurl库的用法实例
Sep 30 Python
Windows上使用virtualenv搭建Python+Flask开发环境
Jun 07 Python
Python连接phoenix的方法示例
Sep 29 Python
使用Python实现在Windows下安装Django
Oct 17 Python
Python 通过requests实现腾讯新闻抓取爬虫的方法
Feb 22 Python
pycharm访问mysql数据库的方法步骤
Jun 18 Python
Python3之不使用第三方变量,实现交换两个变量的值
Jun 26 Python
在Keras中实现保存和加载权重及模型结构
Jun 15 Python
python安装后的目录在哪里
Jun 21 Python
如何用六步教会你使用python爬虫爬取数据
Apr 06 Python
pytorch索引查找 index_select的例子
Aug 18 #Python
浅谈Pytorch中的torch.gather函数的含义
Aug 18 #Python
PyTorch中Tensor的维度变换实现
Aug 18 #Python
PyTorch中Tensor的拼接与拆分的实现
Aug 18 #Python
详解PyTorch中Tensor的高阶操作
Aug 18 #Python
浅析PyTorch中nn.Linear的使用
Aug 18 #Python
Pytorch实现GoogLeNet的方法
Aug 18 #Python
You might like
PHP脚本数据库功能详解(中)
2006/10/09 PHP
php下MYSQL limit的优化
2008/01/10 PHP
PHP session有效期session.gc_maxlifetime
2011/04/20 PHP
php中替换字符串中的空格为逗号','的方法
2014/06/09 PHP
PHP实现无限分类的实现方法
2016/11/14 PHP
Laravel中的chunk组块结果集处理与注意问题
2018/08/15 PHP
php集成开发环境详解
2019/09/24 PHP
js监听表单value的修改同步问题,跨浏览器支持
2009/12/31 Javascript
JQuery Ajax通过Handler访问外部XML数据的代码
2010/06/01 Javascript
jquery实现盒子下拉效果示例代码
2013/09/12 Javascript
jquery实现简易的移动端验证表单
2015/11/08 Javascript
Ionic2系列之使用DeepLinker实现指定页面URL
2016/11/21 Javascript
微信小程序 列表的上拉加载和下拉刷新的实现
2017/04/01 Javascript
node.js + socket.io 实现点对点随机匹配聊天
2017/06/30 Javascript
jackson解析json字符串,首字母大写会自动转为小写的方法
2017/12/22 Javascript
小程序实现左滑删除效果
2019/07/25 Javascript
js实现微信聊天效果
2020/08/09 Javascript
JS绘图Flot如何实现动态可刷新曲线图
2020/10/16 Javascript
python实现支持目录FTP上传下载文件的方法
2015/06/03 Python
python在控制台输出进度条的方法
2015/06/20 Python
Python基于PycURL实现POST的方法
2015/07/25 Python
Python 专题六 局部变量、全局变量global、导入模块变量
2017/03/20 Python
一个基于flask的web应用诞生 记录用户账户登录状态(6)
2017/04/11 Python
Python 一键获取百度网盘提取码的方法
2019/08/01 Python
python点击鼠标获取坐标(Graphics)
2019/08/10 Python
python numpy之np.random的随机数函数使用介绍
2019/10/06 Python
Python 取numpy数组的某几行某几列方法
2019/10/24 Python
浅谈Django中的QueryDict元素为数组的坑
2020/03/31 Python
python selenium 获取接口数据的实现
2020/12/07 Python
HTML5标签大全
2016/11/23 HTML / CSS
JPA的优势都有哪些
2013/07/04 面试题
电话销售经理岗位职责
2013/12/07 职场文书
大二学期个人自我评价
2014/01/13 职场文书
酒店采购员岗位职责
2014/03/14 职场文书
学校组织向国旗敬礼活动方案(中小学适用)
2014/09/27 职场文书
深入理解go slice结构
2021/09/15 Golang