pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python生成日历实例解析
Aug 21 Python
python实现登陆知乎获得个人收藏并保存为word文件
Mar 16 Python
Python实现简单拆分PDF文件的方法
Jul 30 Python
Python爬取十篇新闻统计TF-IDF
Jan 03 Python
Python实现的简单计算器功能详解
Aug 25 Python
python学生信息管理系统(完整版)
Apr 05 Python
numpy和pandas中数组的合并、拉直和重塑实例
Jun 28 Python
python实现网站用户名密码自动登录功能
Aug 09 Python
python GUI库图形界面开发之PyQt5窗口类QMainWindow详细使用方法
Feb 26 Python
Python logging日志模块 配置文件方式
Jul 12 Python
Python 带星号(* 或 **)的函数参数详解
Feb 23 Python
用Python实现一个打字速度测试工具来测试你的手速
May 28 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
php中3种方法统计字符串中每种字符的个数并排序
2012/08/27 PHP
PHP strip_tags()去除HTML、XML以及PHP的标签介绍
2014/02/18 PHP
linux下编译安装memcached服务
2014/08/03 PHP
php中通过DirectoryIterator删除整个目录的方法
2015/03/13 PHP
php通过递归方式复制目录和子目录的方法
2015/03/13 PHP
ThinkPHP文件缓存类代码分享
2015/04/22 PHP
二级域名转向类
2006/11/09 Javascript
文档对象模型DOM通俗讲解
2013/11/01 Javascript
jquery(hide方法)隐藏指定元素实例
2013/11/11 Javascript
jquery showModelDialog的使用方法示例详解
2013/11/19 Javascript
javascript解析json实例详解
2014/11/05 Javascript
JavaScript实现刷新不重记的倒计时
2016/08/10 Javascript
在html中引入外部js文件,并调用带参函数的方法
2016/10/31 Javascript
微信小程序 动态的设置图片的高度和宽度详解及实例代码
2017/02/24 Javascript
详解Vue微信授权登录前后端分离较为优雅的解决方案
2018/06/29 Javascript
JQuery animate动画应用示例
2019/05/14 jQuery
vue+Element实现搜索关键字高亮功能
2019/05/28 Javascript
vue服务端渲染操作简单入门实例分析
2019/08/28 Javascript
如何通过vscode运行调试javascript代码
2020/07/24 Javascript
小程序自定义弹框效果
2020/11/16 Javascript
Python中dictionary items()系列函数的用法实例
2014/08/21 Python
Python实现的一个找零钱的小程序代码分享
2014/08/25 Python
Python实现的生成格雷码功能示例
2018/01/24 Python
对Python 数组的切片操作详解
2018/07/02 Python
总结python中pass的作用
2019/02/27 Python
PyQt5实现暗黑风格的计时器
2019/07/29 Python
Python xlwings插入Excel图片的实现方法
2021/02/26 Python
合作协议书范文
2014/08/20 职场文书
施工安全协议书范本
2014/09/26 职场文书
企业务虚会发言材料
2014/10/20 职场文书
重阳节慰问信
2015/02/15 职场文书
2015年环卫处个人工作总结
2015/07/27 职场文书
致毕业季:你如何做好自己的职业生涯规划书?
2019/07/01 职场文书
Mysql 设置boolean类型的操作
2021/06/04 MySQL
go开发alertmanger实现钉钉报警
2021/07/16 Golang
小喇叭开始广播了! 四十多年前珍贵老照片
2022/05/09 无线电