pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用cStringIO实现临时内存文件访问的方法
Mar 26 Python
python3使用SMTP发送简单文本邮件
Jun 19 Python
Python实现的微信好友数据分析功能示例
Jun 21 Python
python单例模式获取IP代理的方法详解
Sep 13 Python
Python函数基础实例详解【函数嵌套,命名空间,函数对象,闭包函数等】
Mar 30 Python
python游戏开发之视频转彩色字符动画
Apr 26 Python
pyqt5实现按钮添加背景图片以及背景图片的切换方法
Jun 13 Python
python同步windows和linux文件
Aug 29 Python
python 发送json数据操作实例分析
Oct 15 Python
Django 如何使用日期时间选择器规范用户的时间输入示例代码详解
May 22 Python
Windows 平台做 Python 开发的最佳组合(推荐)
Jul 27 Python
安装Anaconda3及使用Jupyter的方法
Oct 27 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
PHP+Memcache实现wordpress访问总数统计(非插件)
2014/07/04 PHP
php cookie工作原理与实例详解
2016/07/18 PHP
js停止输出代码
2008/07/20 Javascript
innerhtml用法 innertext用法 以及innerHTML与innertext的区别
2009/10/26 Javascript
ExtJs的Date格式字符代码
2010/12/30 Javascript
poshytip 基于jquery的 插件 主要用于显示微博人的图像和鼠标提示等
2012/10/12 Javascript
测试IE浏览器对JavaScript的AngularJS的兼容性
2015/06/19 Javascript
jQuery和hwSlider实现内容响应式可触控滑动切换效果附源码下载(二)
2016/06/22 Javascript
NodeJS处理Express中异步错误
2017/03/26 NodeJs
underscore之Chaining_动力节点Java学院整理
2017/07/10 Javascript
js实时监控文本框输入字数的实例代码
2018/01/18 Javascript
vue移动端html5页面根据屏幕适配的四种解决方法
2018/10/19 Javascript
Vue CLI3.0中使用jQuery和Bootstrap的方法
2019/02/28 jQuery
Vue开发中常见的套路和技巧总结
2020/11/24 Vue.js
[36:17]DOTA2上海特级锦标赛 - VGL音乐会全集
2016/03/06 DOTA
python中利用zfill方法自动给数字前面补0
2018/04/10 Python
Python使用min、max函数查找二维数据矩阵中最小、最大值的方法
2018/05/15 Python
numpy使用fromstring创建矩阵的实例
2018/06/15 Python
使用python 写一个静态服务(实战)
2019/06/28 Python
Python用字典构建多级菜单功能
2019/07/11 Python
详解Python 4.0 预计推出的新功能
2019/07/26 Python
Pytorch 实现冻结指定卷积层的参数
2020/01/06 Python
python/golang 删除链表中的元素
2020/09/14 Python
出门问问全球官方商城:Tichome音箱和TicWatch智能手表
2017/12/02 全球购物
毕业生文员求职信
2013/11/03 职场文书
查环查孕证明
2014/01/10 职场文书
酒店拾金不昧表扬信
2014/01/18 职场文书
《第一朵杏花》教学反思
2014/04/16 职场文书
捐款倡议书怎么写
2014/05/13 职场文书
承诺书样本
2014/08/30 职场文书
离婚财产分配协议书
2014/10/21 职场文书
群众路线教育实践活动方案
2014/10/31 职场文书
2014年话务员工作总结
2014/11/19 职场文书
golang gopm get -g -v 无法获取第三方库的解决方案
2021/05/05 Golang
pytorch实现ResNet结构的实例代码
2021/05/17 Python
Python学习之包与模块详解
2022/03/19 Python