PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现TCP/IP协议下的端口转发及重定向示例
Jun 14 Python
基于Python实现对PDF文件的OCR识别
Aug 05 Python
Python实现SSH远程登陆,并执行命令的方法(分享)
May 08 Python
Python操作使用MySQL数据库的实例代码
May 25 Python
Python实现的手机号归属地相关信息查询功能示例
Jun 08 Python
Python装饰器模式定义与用法分析
Aug 06 Python
python try except返回异常的信息字符串代码实例
Aug 15 Python
Python-Flask:动态创建表的示例详解
Nov 22 Python
浅谈keras的深度模型训练过程及结果记录方式
Jan 24 Python
Python龙贝格法求积分实例
Feb 29 Python
Python连接HDFS实现文件上传下载及Pandas转换文本文件到CSV操作
Jun 06 Python
Python如何利用Har文件进行遍历指定字典替换提交的数据详解
Nov 05 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
阿拉伯的咖啡与水烟
2021/03/03 咖啡文化
Apache环境下PHP利用HTTP缓存协议原理解析及应用分析
2010/02/16 PHP
用穿越火线快速入门php面向对象
2012/02/22 PHP
PHP中的生成XML文件的4种方法分享
2012/10/06 PHP
yii2.0使用Plupload实现带缩放功能的多图上传
2015/12/22 PHP
微信支付扫码支付php版
2016/07/22 PHP
tp5框架基于ajax实现异步删除图片的方法示例
2020/02/10 PHP
用javascript实现无刷新更新数据的详细步骤 asp
2006/12/26 Javascript
Jquery 表格合并的问题分享
2011/09/17 Javascript
jquery获取元素值的方法(常见的表单元素)
2013/11/15 Javascript
js实现表单多按钮提交action的处理方法
2015/10/24 Javascript
JavaScript html5 canvas绘制时钟效果
2016/03/01 Javascript
基于Phantomjs生成PDF的实现方法
2016/11/07 Javascript
解析JavaScript实现DDoS攻击原理与保护措施
2016/12/26 Javascript
JS将unicode码转中文方法
2017/05/08 Javascript
理解Angular的providers给Http添加默认headers
2017/07/04 Javascript
js笔试题-接收get请求参数
2019/06/15 Javascript
详解JSON.stringify()的5个秘密特性
2020/05/26 Javascript
Python判断操作系统类型代码分享
2014/11/22 Python
Window环境下Scrapy开发环境搭建
2018/11/18 Python
Python不同目录间进行模块调用的实现方法
2019/01/29 Python
Python基于OpenCV实现人脸检测并保存
2019/07/23 Python
【HTML5】3D模型--百行代码实现旋转立体魔方实例
2016/12/16 HTML / CSS
CAT鞋美国官网:CAT Footwear
2017/11/27 全球购物
局域网标准
2016/09/10 面试题
介绍一下gcc特性
2015/10/31 面试题
营销专业应届生求职信
2013/11/26 职场文书
英文导游欢迎词
2014/01/11 职场文书
项目合作协议书
2014/04/16 职场文书
体育运动口号
2014/06/09 职场文书
2014年小学重阳节活动策划方案
2014/09/16 职场文书
大学生思想道德自我评价
2015/03/09 职场文书
人事行政助理岗位职责
2015/04/11 职场文书
2016年第32个教师节红领巾广播稿
2015/12/18 职场文书
大学生村官驻村工作心得体会
2016/01/23 职场文书
使用 DataAnt 监控 Apache APISIX的原理解析
2022/07/07 Servers