PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python代码写的12306订票代码
Dec 20 Python
Python使用Matplotlib实现Logos设计代码
Dec 25 Python
python遍历文件夹下所有excel文件
Jan 03 Python
Python 查看文件的读写权限方法
Jan 23 Python
python中使用PIL制作并验证图片验证码
Mar 15 Python
基于Python实现扑克牌面试题
Dec 11 Python
利用Python计算KS的实例详解
Mar 03 Python
python3 实现口罩抽签的功能
Mar 11 Python
Django vue前后端分离整合过程解析
Nov 20 Python
python tkinter模块的简单使用
Apr 07 Python
python利用pandas分析学生期末成绩实例代码
Jul 09 Python
一起来学习Python的元组和列表
Mar 13 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
适用于php-5.2 的 php.ini 中文版[金步国翻译]
2011/04/17 PHP
thinkphp多表查询两表有重复相同字段的完美解决方法
2016/09/22 PHP
php 防止表单重复提交两种实现方法
2016/11/03 PHP
[企业公众号]升级到[企业微信]之后发送消息失败的解决方法
2017/06/30 PHP
在IE中调用javascript打开Excel的代码(downmoon原作)
2007/04/02 Javascript
JavaScript 继承的实现
2009/07/09 Javascript
jquery中使用$(#form).submit()重写提交表单无效原因分析及解决
2013/03/25 Javascript
Jquery仿淘宝京东多条件筛选可自行结合ajax加载示例
2013/08/28 Javascript
javascript获取鼠标点击元素对象(示例代码)
2013/12/20 Javascript
javaScript年份下拉列表框内容为当前年份及前后50年
2014/05/28 Javascript
浅析jQuery中调用ajax方法时在不同浏览器中遇到的问题
2014/06/11 Javascript
javascript中clipboardData对象用法详解
2015/05/13 Javascript
JavaScript File API实现文件上传预览
2016/02/02 Javascript
js原型链与继承解析(初体验)
2016/05/09 Javascript
node.js 中国天气预报 简单实现
2016/06/06 Javascript
JavaScript中ES6字符串扩展方法
2016/08/26 Javascript
javascript闭包功能与用法实例分析
2017/04/06 Javascript
Angular directive递归实现目录树结构代码实例
2017/05/05 Javascript
nodejs中Express与Koa2对比分析
2018/02/06 NodeJs
layui-table对返回的数据进行转变显示的实例
2019/09/04 Javascript
vue中nextTick用法实例
2019/09/11 Javascript
vue实现配置全局访问路径头(axios)
2019/11/01 Javascript
node.js事件轮询机制原理知识点
2019/12/22 Javascript
js实现带有动画的返回顶部
2020/08/09 Javascript
python批量修改文件后缀示例代码分享
2013/12/24 Python
python模块之StringIO使用示例
2015/04/08 Python
python中函数默认值使用注意点详解
2016/06/01 Python
详解如何管理多个Python版本和虚拟环境
2019/05/10 Python
django删除表重建的实现方法
2019/08/28 Python
Python实现Restful API的例子
2019/08/31 Python
分享8点超级有用的Python编程建议(推荐)
2019/10/13 Python
flask框架json数据的拿取和返回操作示例
2019/11/28 Python
婚庆公司计划书
2014/09/15 职场文书
交通安全教育心得体会
2016/01/15 职场文书
python基础之错误和异常处理
2021/10/24 Python
html用代码制作虚线框怎么做? dw制作虚线圆圈的技巧
2022/12/24 HTML / CSS