pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中逗号的三种作用实例分析
Jun 08 Python
python编程线性回归代码示例
Dec 07 Python
解决pip install的时候报错timed out的问题
Jun 12 Python
django2.0扩展用户字段示例
Feb 13 Python
详解Python3网络爬虫(二):利用urllib.urlopen向有道翻译发送数据获得翻译结果
May 07 Python
python获取点击的坐标画图形的方法
Jul 09 Python
将labelme格式数据转化为标准的coco数据集格式方式
Feb 17 Python
Python中求对数方法总结
Mar 10 Python
Python实现Telnet自动连接检测密码的示例
Apr 16 Python
python编写五子棋游戏
May 25 Python
python 如何用terminal输入参数
May 25 Python
Python加密与解密模块hashlib与hmac
Jun 05 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
PHP代码保护--Zend Guard的使用详解
2013/06/03 PHP
推荐一款MAC OS X 下php集成开发环境mamp
2014/11/08 PHP
php使用iconv中文截断问题的解决方法
2015/02/11 PHP
PHP的文件操作与算法实现的面试题示例
2015/08/10 PHP
Yii编程开发常见调用技巧集锦
2016/07/15 PHP
php实现的pdo公共类定义与用法示例
2017/07/19 PHP
PHP文件管理之实现网盘及压缩包的功能操作
2017/09/20 PHP
Javascript基础知识(一)核心基础语法与事件模型
2014/09/29 Javascript
JS仿iGoogle自定义首页模块拖拽特效的方法
2015/02/13 Javascript
JavaScript每天定时更换皮肤样式的方法
2015/07/01 Javascript
JavaScript如何实现组合列表框中元素移动效果
2016/03/01 Javascript
js 转json格式的字符串为对象或数组(前后台)的方法
2016/11/02 Javascript
JavaScript中for循环的几种写法与效率总结
2017/02/03 Javascript
es6学习笔记之Async函数的使用示例
2017/05/11 Javascript
浅谈Node.js之异步流控制
2017/10/25 Javascript
[原创]微信小程序获取网络类型的方法示例
2019/03/01 Javascript
解决layui-table单元格设置为百分比在ie8下不能自适应的问题
2019/09/28 Javascript
JS实现audio音频剪裁剪切复制播放与上传(步骤详解)
2020/07/28 Javascript
浅谈vue获得后台数据无法显示到table上面的坑
2020/08/13 Javascript
Django中多种重定向方法使用详解
2019/07/17 Python
Python随机函数库random的使用方法详解
2019/08/21 Python
python进程池实现的多进程文件夹copy器完整示例
2019/11/27 Python
Python统计时间内的并发数代码实例
2019/12/28 Python
深入浅出CSS3 background-clip,background-origin和border-image教程
2011/01/27 HTML / CSS
萨克斯第五大道英国:Saks Fifth Avenue英国
2019/04/01 全球购物
Groupon比利时官方网站:特卖和网上购物高达-70%
2019/08/09 全球购物
戴尔荷兰官方网站:Dell荷兰
2020/10/04 全球购物
护理工作感言
2014/01/16 职场文书
售后服务承诺书
2014/03/26 职场文书
交通事故赔偿协议书怎么写
2014/10/04 职场文书
2014年图书管理员工作总结
2014/12/01 职场文书
英文升职感谢信
2015/01/23 职场文书
2015年体检中心工作总结
2015/05/27 职场文书
2016大学生暑期三下乡心得体会
2016/01/23 职场文书
使用Navicat Premium工具将oracle数据库迁移到MySQL
2021/05/27 Oracle
基于PostgreSQL/openGauss 的分布式数据库解决方案
2021/12/06 PostgreSQL