PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中AND、OR的一个使用小技巧
Feb 18 Python
尝试使用Python多线程抓取代理服务器IP地址的示例
Nov 09 Python
用yum安装MySQLdb模块的步骤方法
Dec 15 Python
python扫描proxy并获取可用代理ip的实例
Aug 07 Python
Python实现将通信达.day文件读取为DataFrame
Dec 22 Python
Python图像处理之直线和曲线的拟合与绘制【curve_fit()应用】
Dec 26 Python
python 通过麦克风录音 生成wav文件的方法
Jan 09 Python
Python 中Django验证码功能的实现代码
Jun 20 Python
python性能测量工具cProfile使用解析
Sep 26 Python
pyecharts调整图例与各板块的位置间距实例
May 16 Python
Python和Bash结合在一起的方法
Nov 13 Python
浅谈Python中的正则表达式
Jun 28 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
ThinkPHP里用U方法调用js文件实例
2015/06/18 PHP
PHP简单检测网址是否能够正常打开的方法
2016/09/04 PHP
JavaScript与DropDownList 区别分析
2010/01/01 Javascript
解决ExtJS在chrome或火狐中正常显示在ie中不显示的浏览器兼容问题
2013/01/11 Javascript
使用js实现按钮控制文本框加1减1应用于小时+分钟
2013/12/09 Javascript
在Node.js应用中读写Redis数据库的简单方法
2015/06/30 Javascript
浅谈js多维数组和hash数组定义和使用
2016/07/27 Javascript
原生JS实现层叠轮播图
2017/05/17 Javascript
vue2.0+vuex+localStorage代办事项应用实现详解
2018/05/31 Javascript
vue中 v-for循环的用法详解
2020/02/19 Javascript
在Python的Tornado框架中实现简单的在线代理的教程
2015/05/02 Python
python入门教程 python入门神图一张
2018/03/05 Python
python实现人脸识别经典算法(一) 特征脸法
2018/03/13 Python
python 限制函数调用次数的实例讲解
2018/04/21 Python
Python实现微信消息防撤回功能的实例代码
2019/04/29 Python
python异常触发及自定义异常类解析
2019/08/06 Python
python+tifffile之tiff文件读写方式
2020/01/13 Python
Python中url标签使用知识点总结
2020/01/16 Python
python 中的[:-1]和[::-1]的具体使用
2020/02/13 Python
解决Python 异常TypeError: cannot concatenate 'str' and 'int' objects
2020/04/08 Python
python中实现词云图的示例
2020/12/19 Python
使用css3和jquery实现可伸缩搜索框
2014/02/12 HTML / CSS
用HTML5制作数字时钟的教程
2015/05/11 HTML / CSS
澳大利亚儿童和婴儿产品在线商店:Lime Tree Kids
2017/10/05 全球购物
美国婴儿用品及配件购买网站:Munchkin
2019/04/03 全球购物
C语言怎样定义和声明全局变量和函数最好
2013/11/26 面试题
C#中的验证控件有几种
2014/03/08 面试题
总经理助理的八要求
2013/11/12 职场文书
幼儿园门卫制度
2014/01/29 职场文书
化学系大学生自荐信范文
2014/03/01 职场文书
生产部厂长助理职位说明书
2014/03/03 职场文书
党员一句话承诺大全
2014/03/28 职场文书
小学三年级学生评语
2014/04/22 职场文书
防灾减灾标语
2014/10/07 职场文书
公司处罚决定书
2015/06/24 职场文书
2016年基层党支部书记公开承诺书
2016/03/25 职场文书