PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
自动化Nginx服务器的反向代理的配置方法
Jun 28 Python
python+VTK环境搭建及第一个简单程序代码
Dec 13 Python
pandas通过loc生成新的列方法
Nov 28 Python
Python 实现数据结构-循环队列的操作方法
Jul 17 Python
python3 反射的四种基本方法解析
Aug 26 Python
python编写猜数字小游戏
Oct 06 Python
python使用pygame实现笑脸乒乓球弹珠球游戏
Nov 25 Python
PyCharm中Matplotlib绘图不能显示UI效果的问题解决
Mar 12 Python
python中取绝对值简单方法总结
Jul 24 Python
通过代码简单了解django model序列化作用
Nov 12 Python
常用的Python代码调试工具总结
Jun 23 Python
Python 可迭代对象 iterable的具体使用
Aug 07 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
php 操作excel文件的方法小结
2009/12/31 PHP
在PHP中实现Javascript的escape()函数代码
2010/08/08 PHP
php将gd生成的图片缓存到memcache的小例子
2013/06/05 PHP
php遍历文件夹所有文件子文件夹函数代码
2013/11/27 PHP
php生成缩略图填充白边(等比缩略图方案)
2013/12/25 PHP
PHP页面转UTF-8中文编码乱码的解决办法
2015/10/20 PHP
php自动加载代码实例详解
2021/02/26 PHP
一个可拖拽列宽表格实例演示
2012/11/26 Javascript
javascript实现跳转菜单的具体方法
2013/07/05 Javascript
javascript禁用Tab键脚本实例
2013/11/22 Javascript
JQuery中DOM事件合成用法实例分析
2015/06/13 Javascript
jQuery实现的瀑布流加载效果示例
2016/09/13 Javascript
mvc 、bootstrap 结合分布式图简单实现分页
2016/10/10 Javascript
JavaScript的词法结构精华篇
2018/10/17 Javascript
vue自定义键盘信息、监听数据变化的方法示例【基于vm.$watch】
2019/03/16 Javascript
[50:29]2014 DOTA2华西杯精英邀请赛 5 24 DK VS iG
2014/05/26 DOTA
[05:11]TI9战队采访——VIRTUSPRO
2019/08/22 DOTA
python使用PIL缩放网络图片并保存的方法
2015/04/24 Python
python 使用get_argument获取url query参数
2017/04/28 Python
Python实现Smtplib发送带有各种附件的邮件实例
2017/06/05 Python
解决python Markdown模块乱码的问题
2019/02/14 Python
对Django中的权限和分组管理实例讲解
2019/08/16 Python
python取均匀不重复的随机数方式
2019/11/27 Python
python 实现将小图片放到另一个较大的白色或黑色背景图片中
2019/12/12 Python
python中68个内置函数的总结与介绍
2020/02/24 Python
Python callable内置函数原理解析
2020/03/05 Python
pycharm配置python 设置pip安装源为豆瓣源
2021/02/05 Python
HTML5学习笔记之History API
2015/02/26 HTML / CSS
中国首家奢侈品O2O网购平台:第五大道奢侈品网
2017/12/14 全球购物
美体小铺波兰官方网站:The Body Shop波兰
2019/09/03 全球购物
水毁工程实施方案
2014/04/01 职场文书
授权委托书格式模板
2014/04/03 职场文书
客服专员岗位职责
2015/02/10 职场文书
昆虫记读书笔记
2015/06/26 职场文书
Python+Tkinter制作专属图形化界面
2022/04/01 Python
解决Python保存文件名太长OSError: [Errno 36] File name too long
2022/05/11 Python