PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的super()方法使用简介
Aug 14 Python
python 自动化将markdown文件转成html文件的方法
Sep 23 Python
Python 转义字符详细介绍
Mar 21 Python
EM算法的python实现的方法步骤
Jan 02 Python
浅谈Pandas 排序之后索引的问题
Jun 07 Python
Python干货:分享Python绘制六种可视化图表
Aug 27 Python
对pandas通过索引提取dataframe的行方法详解
Feb 01 Python
面向对象学习之pygame坦克大战
Sep 11 Python
TensorFlow使用Graph的基本操作的实现
Apr 22 Python
python调用API接口实现登陆短信验证
May 10 Python
Jupyter notebook快速入门教程(推荐)
May 18 Python
Python提取PDF指定内容并生成新文件
Jun 09 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
PHP创建/删除/复制文件夹、文件
2016/05/03 PHP
php实现的网页版剪刀石头布游戏示例
2016/11/25 PHP
用JavaScript实现仿Windows关机效果
2007/03/10 Javascript
JQuery入门——用bind方法绑定事件处理函数应用介绍
2013/02/05 Javascript
JS实现点击按钮后框架内载入不同网页的方法
2015/05/05 Javascript
JavaScript实现获取某个元素相邻兄弟节点的prev与next方法
2016/01/25 Javascript
Javascript中函数名.length属性用法分析(对比arguments.length)
2016/09/16 Javascript
Bootstrap源码解读媒体对象、列表组和面板(10)
2016/12/26 Javascript
关于vue.js组件数据流的问题
2017/07/26 Javascript
vue + vuex todolist的实现示例代码
2018/03/09 Javascript
JavaScript 对引擎、运行时、调用堆栈的概述理解
2018/10/22 Javascript
详解使用JWT实现单点登录(完全跨域方案)
2019/08/02 Javascript
[34:56]Ti4冒泡赛LGD vs Liquid 1
2014/07/14 DOTA
python学习 流程控制语句详解
2016/06/01 Python
Python3 伪装浏览器的方法示例
2017/11/23 Python
Python机器学习之K-Means聚类实现详解
2018/02/22 Python
python通过配置文件共享全局变量的实例
2019/01/11 Python
Python数据类型之Set集合实例详解
2019/05/07 Python
pyqt 实现QlineEdit 输入密码显示成圆点的方法
2019/06/24 Python
python3 自动识别usb连接状态,即对usb重连的判断方法
2019/07/03 Python
在Python中获取操作系统的进程信息
2019/08/27 Python
Python 实现try重新执行
2019/12/21 Python
python实现ip地址的包含关系判断
2020/02/07 Python
css3实现可拖动的魔方3d效果
2019/05/07 HTML / CSS
英国工艺品购物网站:Minerva Crafts
2018/01/29 全球购物
DC Shoes澳大利亚官方网上商店:购买DC鞋子
2019/10/25 全球购物
Java面试题:请说出如下代码的输出结果
2013/04/22 面试题
水利公司纪检监察自我鉴定
2014/02/25 职场文书
我的梦中国梦演讲稿
2014/04/23 职场文书
英文演讲稿开场白
2014/08/25 职场文书
2015年安全生产工作总结范文
2015/04/02 职场文书
不同意离婚上诉状
2015/05/23 职场文书
ORACLE数据库对long类型字段进行模糊匹配的解决思路
2021/04/07 Oracle
Python机器学习算法之决策树算法的实现与优缺点
2021/05/13 Python
微信小程序实现聊天室功能
2021/06/14 Javascript
Python实战之OpenCV实现猫脸检测
2021/06/26 Python