pytorch动态网络以及权重共享实例


Posted in Python onJanuary 06, 2020

pytorch 动态网络+权值共享

pytorch以动态图著称,下面以一个栗子来实现动态网络和权值共享技术:

# -*- coding: utf-8 -*-
import random
import torch


class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    这里构造了几个向前传播过程中用到的线性函数
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.

    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.

    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    这里中间层每次向前过程中都是随机添加0-3层,而且中间层都是使用的同一个线性层,这样计算时,权值也是用的同一个。
    """
    h_relu = self.input_linear(x).clamp(min=0)
    for _ in range(random.randint(0, 3)):
      h_relu = self.middle_linear(h_relu).clamp(min=0)
    y_pred = self.output_linear(h_relu)
    return y_pred


    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    # Construct our model by instantiating the class defined above
    model = DynamicNet(D_in, H, D_out)

    # Construct our loss function and an Optimizer. Training this strange model with
    # vanilla stochastic gradient descent is tough, so we use momentum
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    for t in range(500):
      # Forward pass: Compute predicted y by passing x to the model
      y_pred = model(x)

      # Compute and print loss
      loss = criterion(y_pred, y)
      print(t, loss.item())

      # Zero gradients, perform a backward pass, and update the weights.
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

这个程序实际上是一种RNN结构,在执行过程中动态的构建计算图

References: Pytorch Documentations.

以上这篇pytorch动态网络以及权重共享实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python通过定义一个类实例作为ftp回调方法
May 04 Python
PyQt5每天必学之布局管理
Apr 19 Python
转换科学计数法的数值字符串为decimal类型的方法
Jul 16 Python
django_orm查询性能优化方法
Aug 20 Python
python and or用法详解
Jun 26 Python
Django CSRF跨站请求伪造防护过程解析
Jul 31 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 Python
Matplotlib自定义坐标轴刻度的实现示例
Jun 18 Python
keras实现theano和tensorflow训练的模型相互转换
Jun 19 Python
python 装饰器的实际作用有哪些
Sep 07 Python
Python中requests库的用法详解
Jun 05 Python
python基础之//、/与%的区别详解
Jun 10 Python
selenium中get_cookies()和add_cookie()的用法详解
Jan 06 #Python
pytorch中的自定义反向传播,求导实例
Jan 06 #Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
You might like
php数组函数序列之asort() - 对数组的元素值进行升序排序,保持索引关系
2011/11/02 PHP
php无限极分类递归排序实现方法
2014/11/11 PHP
php自动加载方式集合
2016/04/04 PHP
javascript 三种编解码方式
2010/02/01 Javascript
jQuery 幻灯片插件(带缩略图功能)
2011/01/24 Javascript
js查找某元素中的所有图片地址的方法
2014/01/16 Javascript
js 动态修改css文件用到了cssRule
2014/08/20 Javascript
七夕情人节丘比特射箭小游戏
2015/08/20 Javascript
Bootstrap表单布局
2016/07/19 Javascript
JS button按钮实现submit按钮提交效果
2016/11/01 Javascript
Reactjs实现通用分页组件的实例代码
2017/01/19 Javascript
for循环 + setTimeout 结合一些示例(前端面试题)
2017/08/30 Javascript
React Native 通告消息竖向轮播组件的封装
2020/08/25 Javascript
bootstrap table支持高度百分比的实例代码
2018/02/28 Javascript
Javascript 之封装(Package)
2018/09/14 Javascript
vue实现同一个页面可以有多个router-view的方法
2018/09/20 Javascript
微信小程序动画(Animation)的实现及执行步骤
2018/10/28 Javascript
JavaScript判断数组类型的方法
2019/10/23 Javascript
Openlayers3实现车辆轨迹回放功能
2020/09/29 Javascript
vue的hash值原理也是table切换实例代码
2020/12/14 Vue.js
Python的lambda匿名函数的简单介绍
2013/04/25 Python
Python对象体系深入分析
2014/10/28 Python
python模块之StringIO使用示例
2015/04/08 Python
Django中对通过测试的用户进行限制访问的方法
2015/07/23 Python
Tensorflow实现神经网络拟合线性回归
2019/07/19 Python
python3.7+selenium模拟淘宝登录功能的实现
2020/05/26 Python
Python实现读取并写入Excel文件过程解析
2020/05/27 Python
python利用tkinter实现图片格式转换的示例
2020/09/28 Python
新西兰领先的鞋类和靴子网上商城:Merchant 1948
2017/09/08 全球购物
旧时光糖果:Old Time Candy
2018/02/05 全球购物
贪睡宠物用品:Snoozer Pet Products
2020/02/04 全球购物
财务管理专业毕业生求职信范文
2013/09/21 职场文书
师德承诺书
2015/01/20 职场文书
你离财务总监还有多远?速览CFO的岗位职责
2019/11/18 职场文书
在 SQL 语句中处理 NULL 值的方法
2021/06/07 SQL Server
python中的3种定义类方法
2021/11/27 Python