pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
十条建议帮你提高Python编程效率
Feb 16 Python
python中set常用操作汇总
Jun 30 Python
Python实现读取txt文件并画三维图简单代码示例
Dec 09 Python
Python SVM(支持向量机)实现方法完整示例
Jun 19 Python
Python 打印中文字符的三种方法
Aug 14 Python
对Python 多线程统计所有csv文件的行数方法详解
Feb 12 Python
一文了解Python并发编程的工程实现方法
May 31 Python
django model 条件过滤 queryset.filter(**condtions)用法详解
May 20 Python
Python OpenCV中的numpy与图像类型转换操作
Dec 11 Python
Python3中对json格式数据的分析处理
Jan 28 Python
Python机器学习工具scikit-learn的使用笔记
Jan 28 Python
python中time tzset()函数实例用法
Feb 18 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
在线短消息收发的程序,不用数据库
2006/10/09 PHP
PHP实现MVC开发得最简单的方法――模型
2007/04/10 PHP
php Memcache 中实现消息队列
2009/11/24 PHP
php 5.3.5安装memcache注意事项小结
2011/04/12 PHP
AES加解密在php接口请求过程中的应用示例
2016/10/26 PHP
JavaScript 变量命名规则
2009/09/23 Javascript
跟着JQuery API学Jquery 之二 属性
2010/04/09 Javascript
A标签中通过href和onclick传递的this对象实现思路
2013/04/19 Javascript
jQuery制作仿Mac Lion OS滚动条效果
2015/02/10 Javascript
使用JavaScript的AngularJS库编写hello world的方法
2015/06/23 Javascript
jQuery学习笔记之Ajax用法实例详解
2015/12/01 Javascript
JavaScript编写一个简易购物车功能
2016/09/17 Javascript
微信小程序 配置文件详细介绍
2016/12/14 Javascript
javaScript嗅探执行神器-sniffer.js
2017/02/14 Javascript
jquery获取select,option所有的value和text的实例
2017/03/06 Javascript
jQuery实现的点击标题文字切换字体效果示例【测试可用】
2018/04/26 jQuery
Vue 使用 Mint UI 实现左滑删除效果CellSwipe
2018/04/27 Javascript
Vue一次性简洁明了引入所有公共组件的方法
2018/11/28 Javascript
Javascript实现鼠标点击冒泡特效
2019/12/24 Javascript
[01:10:57]Liquid vs OG 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python循环监控远程端口的方法
2015/03/14 Python
Python实现图片滑动式验证识别方法
2017/11/09 Python
python学习笔记--将python源文件打包成exe文件(pyinstaller)
2018/05/26 Python
Python退火算法在高次方程的应用
2018/07/26 Python
Python+threading模块对单个接口进行并发测试
2019/06/25 Python
pycharm 更改创建文件默认路径的操作
2020/02/15 Python
pytorch读取图像数据转成opencv格式实例
2020/06/02 Python
Python爬虫之Selenium警告框(弹窗)处理
2020/12/04 Python
用Python实现童年贪吃蛇小游戏功能的实例代码
2020/12/07 Python
美国求婚钻戒网站:Super Jeweler
2016/08/27 全球购物
英国领先的运动营养品牌:Protein Dynamix
2018/01/02 全球购物
德尔福集团DELPHI的笔试题
2012/02/22 面试题
写好自荐信的技巧
2013/11/08 职场文书
教育实习生的自我评价分享
2013/11/21 职场文书
排查整治工作方案
2014/06/09 职场文书
php远程请求CURL案例(爬虫、保存登录状态)
2021/04/01 PHP