pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
测试、预发布后用python检测网页是否有日常链接
Jun 03 Python
探究Python中isalnum()方法的使用
May 18 Python
python集合用法实例分析
May 30 Python
以视频爬取实例讲解Python爬虫神器Beautiful Soup用法
Jan 20 Python
Python使用gensim计算文档相似性
Apr 10 Python
AI人工智能 Python实现人机对话
Nov 13 Python
python3实现跳一跳点击跳跃
Jan 08 Python
python实现12306抢票及自动邮件发送提醒付款功能
Mar 08 Python
Python使用type关键字创建类步骤详解
Jul 23 Python
浅析python中while循环和for循环
Nov 19 Python
Python 在 VSCode 中使用 IPython Kernel 的方法详解
Sep 05 Python
Python opencv缺陷检测的实现及问题解决
Apr 24 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
利用PHP制作简单的内容采集器的代码
2007/11/28 PHP
深入php 正则表达式的学习探讨
2013/06/06 PHP
CodeIgniter图像处理类的深入解析
2013/06/17 PHP
PHP获取Exif缩略图的方法
2015/07/13 PHP
Yii 框架控制器创建使用及控制器响应操作示例
2019/10/14 PHP
JS中数组Array的用法示例介绍
2014/02/20 Javascript
jquery文本框中的事件应用以输入邮箱为例
2014/05/06 Javascript
jquery bind(click)传参让列表中每行绑定一个事件
2014/08/06 Javascript
分步解析JavaScript实现tab选项卡自动切换功能
2016/01/25 Javascript
使用Bootstrap框架制作查询页面的界面实例代码
2016/05/27 Javascript
深入理解 JavaScript 中的 JSON
2017/04/06 Javascript
jQuery实现上传图片前预览效果功能
2017/08/03 jQuery
JavaScript中的return布尔值的用法和原理解析
2017/08/14 Javascript
JS基于ES6新特性async await进行异步处理操作示例
2019/02/02 Javascript
Electron+vue从零开始打造一个本地播放器的方法示例
2020/10/27 Javascript
[03:03]DOTA2校园争霸赛 济南城市决赛欢乐发奖活动
2013/10/21 DOTA
[01:06:43]完美世界DOTA2联赛PWL S3 PXG vs GXR 第二场 12.19
2020/12/24 DOTA
跟老齐学Python之总结参数的传递
2014/10/10 Python
Python获取指定字符前面的所有字符方法
2018/05/02 Python
使用python实现快速搭建简易的FTP服务器
2018/09/12 Python
Python读取YAML文件过程详解
2019/12/30 Python
Python爬虫实现百度翻译功能过程详解
2020/05/29 Python
Python ellipsis 的用法详解
2020/11/20 Python
Python实现网络聊天室的示例代码(支持多人聊天与私聊)
2021/01/27 Python
美国知名男士服饰品牌:Brooks Brothers(布克兄弟)
2016/08/25 全球购物
迪拜领先运动补剂零售品牌中文站:Sporter商城
2019/08/20 全球购物
编写一个 C 函数,该函数在一个字符串中找到可能的最长的子字符串,且该字符串是由同一字符组成的
2015/07/23 面试题
介绍一下Linux文件的记录形式
2013/09/29 面试题
给幼儿园老师的表扬信
2014/01/19 职场文书
教师党员先进性教育自我剖析材料思想汇报
2014/09/24 职场文书
JS一分钟在github+Jekyll的博客中添加访问量功能的实现
2021/04/03 Javascript
面试被问select......for update会锁表还是锁行
2021/11/11 MySQL
Python Matplotlib库实现画局部图
2021/11/17 Python
Consul在linux环境的集群部署
2022/04/08 Servers
利用Python实时获取steam特惠游戏数据
2022/06/25 Python
spring boot实现文件上传
2022/08/14 Java/Android