PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python提取html文件中的特定数据的实现代码
Mar 24 Python
python中使用正则表达式的连接符示例代码
Oct 10 Python
selenium+python自动化测试环境搭建步骤
Jun 03 Python
解决Python设置函数调用超时,进程卡住的问题
Aug 08 Python
决策树剪枝算法的python实现方法详解
Sep 18 Python
python抓取多种类型的页面方法实例
Nov 20 Python
python实现同一局域网下传输图片
Mar 20 Python
将pymysql获取到的数据类型是tuple转化为pandas方式
May 15 Python
Python如何实现的二分查找算法
May 27 Python
python cv2.resize函数high和width注意事项说明
Jul 05 Python
基于Pytorch版yolov5的滑块验证码破解思路详解
Feb 25 Python
python基础之文件操作
Oct 24 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
PHP 编写大型网站问题集
2010/05/07 PHP
解析php中curl_multi的应用
2013/07/17 PHP
php实现12306火车票余票查询和价格查询(12306火车票查询)
2014/01/14 PHP
php Imagick获取图片RGB颜色值
2014/07/28 PHP
如何让CI框架支持service层
2014/10/29 PHP
php中isset与empty函数的困惑与用法分析
2019/07/05 PHP
JavaScript小技巧 2.5 则
2010/09/12 Javascript
Android中资源文件(非代码部分)的使用概览
2012/12/18 Javascript
js 窗口抖动示例
2013/09/04 Javascript
弹出窗口并且此窗口带有半透明的遮罩层效果
2014/03/13 Javascript
点击表单提交时出现jQuery没有权限的解决方法
2014/07/23 Javascript
node.js中的fs.read方法使用说明
2014/12/17 Javascript
Javascript基础教程之定义和调用函数
2015/01/18 Javascript
php基于redis处理session的方法
2016/03/14 Javascript
详解Vue前端生产环境发布配置实战篇
2019/05/07 Javascript
详解Python装饰器由浅入深
2016/12/09 Python
Python 备份程序代码实现
2017/03/06 Python
python得到单词模式的示例
2018/10/15 Python
对python制作自己的数据集实例讲解
2018/12/12 Python
django之使用celery-把耗时程序放到celery里面执行的方法
2019/07/12 Python
Pytorch.nn.conv2d 过程验证方式(单,多通道卷积过程)
2020/01/03 Python
Python实现子类调用父类的初始化实例
2020/03/12 Python
css3 线性渐变和径向渐变示例附图
2014/04/08 HTML / CSS
新西兰优惠网站:Treat Me
2019/07/04 全球购物
物业电工岗位职责
2013/11/20 职场文书
《挑山工》的教学反思
2014/02/16 职场文书
校庆筹备方案
2014/03/30 职场文书
高三学生评语大全
2014/04/25 职场文书
党课培训心得体会
2014/09/02 职场文书
2014年中学生检讨书大全
2014/10/09 职场文书
2014年社区妇联工作总结
2014/12/02 职场文书
2015年电厂工作总结范文
2015/05/13 职场文书
重阳节简报
2015/07/20 职场文书
2015元旦感言
2015/12/09 职场文书
2016感恩母亲节校园广播稿
2015/12/17 职场文书
微信小程序中wxs文件的一些妙用分享
2022/02/18 Javascript