pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
通过实例浅析Python对比C语言的编程思想差异
Aug 30 Python
简介Python设计模式中的代理模式与模板方法模式编程
Feb 02 Python
Python中operator模块的操作符使用示例总结
Jun 28 Python
python自动12306抢票软件实现代码
Feb 24 Python
Django3.0 异步通信初体验(小结)
Dec 04 Python
记一次pyinstaller打包pygame项目为exe的过程(带图片)
Mar 02 Python
4行Python代码生成图像验证码(2种)
Apr 07 Python
Python通过Pillow实现图片对比
Apr 29 Python
完美解决TensorFlow和Keras大数据量内存溢出的问题
Jul 03 Python
python基于pexpect库自动获取日志信息
Feb 01 Python
python 逐步回归算法
Apr 06 Python
详解MindSpore自定义模型损失函数
Jun 30 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
基于PHP Socket配置以及实例的详细介绍
2013/06/13 PHP
Opcache导致php-fpm崩溃nginx返回502
2015/03/02 PHP
浅谈关于PHP解决图片无损压缩的问题
2017/09/01 PHP
jquery select下拉框操作的一些说明
2010/04/02 Javascript
Jquery getJSON方法详细分析
2013/12/26 Javascript
JS小游戏之仙剑翻牌源码详解
2014/09/25 Javascript
深入浅出理解javaScript原型链
2015/05/09 Javascript
JavaScript浏览器对象之一Window对象详解
2016/06/03 Javascript
jQuery与JS加载事件用法分析
2016/09/04 Javascript
vue快捷键与基础指令详解
2017/06/01 Javascript
解决bootstrap中使用modal加载kindeditor时弹出层文本框不能输入的问题
2017/06/05 Javascript
js获取文件里面的所有文件名(实例)
2017/10/17 Javascript
微信小程序实现滴滴导航tab切换效果
2018/07/24 Javascript
详解JavaScript中关于this指向的4种情况
2019/04/18 Javascript
使用VScode 插件debugger for chrome 调试react源码的方法
2019/09/13 Javascript
关于layui 下拉列表的change事件详解
2019/09/20 Javascript
JS实现小星星特效
2019/12/24 Javascript
[04:53]DOTA2英雄基础教程 祈求者
2014/01/03 DOTA
python调用shell的方法
2013/11/20 Python
整理Python中的赋值运算符
2015/05/13 Python
利用Python脚本生成sitemap.xml的实现方法
2017/01/31 Python
python中时间转换datetime和pd.to_datetime详析
2019/08/11 Python
Python datetime包函数简单介绍
2019/08/28 Python
在keras里面实现计算f1-score的代码
2020/06/15 Python
详解如何在pyqt中通过OpenCV实现对窗口的透视变换
2020/09/20 Python
英国户外服装、鞋类和设备的领先零售商:Millets
2020/10/12 全球购物
mysql有关权限的表都有哪几个
2015/04/22 面试题
汽车专业毕业生自荐信
2013/11/03 职场文书
市政施工员自我鉴定
2014/01/15 职场文书
农林环境专业求职信
2014/03/13 职场文书
国家领导干部党的群众路线教育实践活动批评与自我批评材料
2014/09/23 职场文书
2015年科室工作总结
2015/04/10 职场文书
婚育证明样本
2015/06/16 职场文书
sql通过日期判断年龄函数的示例代码
2021/07/16 SQL Server
Go中的条件语句Switch示例详解
2021/08/23 Golang
MySQL中varchar和char类型的区别
2021/11/17 MySQL