PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python之wxPython菜单使用详解
Sep 28 Python
python使用Pycharm创建一个Django项目
Mar 05 Python
用python处理图片实现图像中的像素访问
May 04 Python
tensorflow 获取模型所有参数总和数量的方法
Jun 14 Python
Django contenttypes 框架详解(小结)
Aug 13 Python
只需7行Python代码玩转微信自动聊天
Jan 27 Python
Pytorch加载部分预训练模型的参数实例
Aug 18 Python
利用setuptools打包python程序的方法步骤
Jan 18 Python
python生成并处理uuid的实现方式
Mar 03 Python
Pycharm连接远程服务器过程图解
Apr 30 Python
JAVA及PYTHON质数计算代码对比解析
Jun 10 Python
Python中X[:,0]和X[:,1]的用法
May 10 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
MYSQL 小技巧 -- LAST_INSERT_ID
2009/11/24 PHP
php实现utf-8转unicode函数分享
2015/01/06 PHP
PHP数组内存利用率低和弱类型详细解读
2017/08/10 PHP
PHP中关于php.ini参数优化详解
2020/02/28 PHP
TypeScript 学习笔记之基本类型
2015/06/19 Javascript
jQuery简单设置文本框回车事件的方法
2016/08/01 Javascript
JavaScript中的ajax功能的概念和示例详解
2016/10/17 Javascript
JS简单判断滚动条的滚动方向实现方法
2017/04/28 Javascript
VUE axios发送跨域请求需要注意的问题
2017/07/06 Javascript
ng-events类似ionic中Events的angular全局事件
2018/09/05 Javascript
vue3.0 CLI - 3.2 路由的初级使用教程
2018/09/20 Javascript
解决vue-cli webpack打包后加载资源的路径问题
2018/09/25 Javascript
jquery获取img的src值实例介绍
2019/01/16 jQuery
node.js express框架实现文件上传与下载功能实例详解
2019/10/15 Javascript
python django集成cas验证系统
2014/07/14 Python
Python编程之微信推送模板消息功能示例
2017/08/21 Python
Python基于TCP实现会聊天的小机器人功能示例
2018/04/09 Python
在PyCharm中三步完成PyPy解释器的配置的方法
2018/10/29 Python
详解Python 调用C# dll库最简方法
2019/06/20 Python
python yield关键词案例测试
2019/10/15 Python
python使用配置文件过程详解
2019/12/28 Python
python如何写出表白程序
2020/06/01 Python
简单了解Django项目应用创建过程
2020/07/06 Python
pycharm 添加解释器的方法步骤
2020/08/31 Python
CSS3实现粒子旋转伸缩加载动画
2016/04/22 HTML / CSS
使用CSS3实现多列布局与多背景的技巧
2016/02/29 HTML / CSS
英国鞋类及配饰零售商:Kurt Geiger
2017/02/04 全球购物
美国殿堂级滑板、冲浪、滑雪服装品牌:Volcom(钻石)
2017/04/20 全球购物
Linux如何为某个操作添加别名
2015/02/05 面试题
银行青年文明号事迹材料
2014/05/31 职场文书
英语专业求职信
2014/07/08 职场文书
2014年图书室工作总结
2014/12/09 职场文书
2015年推广普通话演讲稿
2015/03/20 职场文书
教师节主题班会教案
2015/08/17 职场文书
2016中秋节问候语
2015/11/11 职场文书
win10输入法不见了只能打出字母怎么解决?
2022/08/05 数码科技