PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
两个使用Python脚本操作文件的小示例分享
Aug 27 Python
python urllib urlopen()对象方法/代理的补充说明
Jun 29 Python
tensorflow构建BP神经网络的方法
Mar 12 Python
Python subprocess模块常见用法分析
Jun 12 Python
解决Tensorflow使用pip安装后没有model目录的问题
Jun 13 Python
python使用Plotly绘图工具绘制气泡图
Apr 01 Python
提升Python程序性能的7个习惯
Apr 14 Python
python 字符串追加实例
Jul 20 Python
pytorch索引查找 index_select的例子
Aug 18 Python
解决import tensorflow as tf 出错的原因
Apr 16 Python
5 分钟读懂Python 中的 Hook 钩子函数
Dec 09 Python
深度学习tensorflow基础mnist
Apr 14 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
利用yahoo汇率接口实现实时汇率转换示例 汇率转换器
2014/01/14 PHP
php读取XML的常见方法实例总结
2017/04/25 PHP
php操作redis数据库常见方法实例总结
2020/02/20 PHP
实现局部遮罩与关闭原理及代码
2013/02/04 Javascript
javascript版的in_array函数(判断数组中是否存在特定值)
2014/05/09 Javascript
Javascript中的getUTCDay()方法使用详解
2015/06/10 Javascript
javascript中undefined与null的区别
2015/08/16 Javascript
浅析javascript的return语句
2015/12/15 Javascript
在JavaScript中模拟类(class)及类的继承关系
2016/05/20 Javascript
快速解决js动态改变dom元素属性后页面及时渲染的问题
2016/07/06 Javascript
详解js中Json的语法与格式
2016/11/22 Javascript
获取JavaScript异步函数的返回值
2016/12/21 Javascript
微信小程序实现皮肤功能(夜间模式)
2017/06/18 Javascript
vue组件学习教程
2017/09/09 Javascript
jQuery实现监听下拉框选中内容发生改变操作示例
2018/07/13 jQuery
NodeJs项目中关闭ESLint的方法
2018/08/09 NodeJs
Node 搭建一个静态资源服务器的实现
2019/05/20 Javascript
vue 开发之路由配置方法详解
2019/12/02 Javascript
python删除某个字符
2018/03/19 Python
django manage.py扩展自定义命令方法
2018/05/27 Python
Python实现通过继承覆盖方法示例
2018/07/02 Python
paramiko使用tail实时获取服务器的日志输出详解
2020/12/06 Python
使用HTML5的表单验证的简单示例
2015/09/09 HTML / CSS
联想香港官方网站及网店:Lenovo香港
2018/04/13 全球购物
一个大学生十年的职业规划
2014/01/17 职场文书
2014年党员公开承诺书范文
2014/03/28 职场文书
演讲稿格式
2014/04/30 职场文书
勿忘国耻9.18演讲稿(经典篇)
2014/09/14 职场文书
人身损害赔偿协议书格式
2014/11/01 职场文书
学校节水倡议书
2015/04/29 职场文书
2015年企业团支部工作总结
2015/05/21 职场文书
初中团支书竞选稿
2015/11/21 职场文书
农村房屋租赁合同(范本)
2019/07/23 职场文书
浅谈redis缓存在项目中的使用
2021/05/20 Redis
Python类方法总结讲解
2021/07/26 Python
详解解Django 多对多表关系的三种创建方式
2021/08/23 Python