PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python进行基础的函数式编程的教程
Mar 31 Python
利用Python中SocketServer 实现客户端与服务器间非阻塞通信
Dec 15 Python
用 Python 爬了爬自己的微信朋友(实例讲解)
Aug 25 Python
Python设计模式之门面模式简单示例
Jan 09 Python
如何利用python查找电脑文件
Apr 27 Python
django+xadmin+djcelery实现后台管理定时任务
Aug 14 Python
pycharm运行程序时在Python console窗口中运行的方法
Dec 03 Python
python 切换root 执行命令的方法
Jan 19 Python
python实现名片管理器的示例代码
Dec 17 Python
python interpolate插值实例
Jul 06 Python
pycharm专业版远程登录服务器的详细教程
Sep 15 Python
Pycharm创建python文件自动添加日期作者等信息(步骤详解)
Feb 03 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
为什么夜间收到的中波电台比白天多
2021/03/01 无线电
PHP中header和session_start前不能有输出原因分析
2013/01/11 PHP
PHP-Java-Bridge使用笔记
2014/09/22 PHP
PHP单例模式是什么 php实现单例模式的方法
2016/05/14 PHP
深入理解PHP JSON数组与对象
2016/07/19 PHP
Netbeans 8.2与PHP相关的新特性介绍
2016/10/08 PHP
Yii2.0实现生成二维码功能实例
2017/10/24 PHP
PHP设计模式之建造者模式定义与用法简单示例
2018/08/13 PHP
实例说明js脚本语言和php脚本语言的区别
2019/04/04 PHP
利用javascript数组长度循环数组内所有元素
2013/12/27 Javascript
代码获取历史上的今天发生的事
2014/04/11 Javascript
使用JavaScript制作一个简单的计数器的方法
2015/07/07 Javascript
jQuery模拟360浏览器切屏效果幻灯片(附demo源码下载)
2016/01/29 Javascript
浅析C/C++,Java,PHP,JavaScript,Json数组、对象赋值时最后一个元素后面是否可以带逗号
2016/03/22 Javascript
JS图片左右无缝隙滚动的实现(兼容IE,Firefox 遵循W3C标准)
2016/09/23 Javascript
Angularjs使用指令做表单校验的方法
2017/03/31 Javascript
JS处理数据四舍五入(tofixed与round的区别详解)
2017/10/26 Javascript
详解如何在Javascript中使用Object.freeze()
2020/10/18 Javascript
Python中的一些陷阱与技巧小结
2015/07/10 Python
13个最常用的Python深度学习库介绍
2017/10/28 Python
Python实现AI自动抠图实例解析
2020/03/05 Python
详解Python中string模块除去Str还剩下什么
2020/11/30 Python
捷克体育用品购物网站:D-sport
2017/12/28 全球购物
幼儿园保育员辞职信
2014/01/12 职场文书
结婚喜宴主持词
2014/03/14 职场文书
岗位职责说明书
2014/05/07 职场文书
村级换届选举方案
2014/05/10 职场文书
高职教师先进事迹材料
2014/08/24 职场文书
驳回起诉裁定书
2015/05/19 职场文书
学术研讨会主持词
2015/07/04 职场文书
2016春季运动会通讯稿
2015/07/18 职场文书
Python 如何实现文件自动去重
2021/06/02 Python
Java面试题冲刺第十五天--设计模式
2021/08/07 面试题
JavaScript高级程序设计之基本引用类型
2021/11/17 Javascript
Spring事务管理下synchronized锁失效问题的解决方法
2022/03/31 Java/Android
插件导致ECharts被全量引入的坑示例解析
2022/09/23 Javascript