PyTorch 如何检查模型梯度是否可导


Posted in Python onJune 05, 2021

一、PyTorch 检查模型梯度是否可导

当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。

首先看一下官方文档中关于该函数的介绍:

PyTorch 如何检查模型梯度是否可导

PyTorch 如何检查模型梯度是否可导

可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:

Tensor需要是双精度浮点型且设置requires_grad = True

第一个例子:检查某一操作是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn
 
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

第二个例子:检查某一网络模型是否可导

from torch.autograd import gradcheck
import torch
import torch.nn as nn 
# 定义神经网络模型
class Net(nn.Module):
 
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(15, 30),
            nn.ReLU(),
            nn.Linear(30, 15),
            nn.ReLU(),
            nn.Linear(15, 1),
            nn.Sigmoid()
        )
 
    def forward(self, x):
        y = self.net(x)
        return y
 
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)

输出为:

Are the gradients correct: True

二、Pytorch求导

1.标量对矩阵求导

PyTorch 如何检查模型梯度是否可导

验证:

>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]])    # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True)  #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b)  # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad   #df/dX = a.dot(b^T)
tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])
>>>a.grad b.grad   # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1))  # a.dot(b^T)
    tensor([[ 2.,  3.,  4.],
    [ 4.,  6.,  8.],
    [ 6.,  9., 12.],
    [ 8., 12., 16.]])

2.矩阵对矩阵求导

PyTorch 如何检查模型梯度是否可导PyTorch 如何检查模型梯度是否可导

验证:

>>>A = torch.tensor([[1,2],[3,4.]])  #2*2矩阵
>>>X =  torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True)  # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
    [19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
    [6., 6., 6.]])

注意:

requires_grad为True的数组必须是float类型

进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
总结python爬虫抓站的实用技巧
Aug 09 Python
Python2中文处理纪要的实现方法
Mar 10 Python
Python sorted函数详解(高级篇)
Sep 18 Python
在Python中,不用while和for循环遍历列表的实例
Feb 20 Python
Django如何开发简单的查询接口详解
May 17 Python
pytorch sampler对数据进行采样的实现
Dec 31 Python
Python3如何在Windows和Linux上打包
Feb 25 Python
Matplotlib.pyplot 三维绘图的实现示例
Jul 28 Python
Python Web项目Cherrypy使用方法镜像
Nov 05 Python
如何用python 操作zookeeper
Dec 28 Python
Python 用户输入和while循环的操作
May 23 Python
Python编写车票订购系统 Python实现快递收费系统
Aug 14 Python
python-opencv 中值滤波{cv2.medianBlur(src, ksize)}的用法
解决Pytorch修改预训练模型时遇到key不匹配的情况
Jun 05 #Python
pytorch 预训练模型读取修改相关参数的填坑问题
Jun 05 #Python
解决pytorch 损失函数中输入输出不匹配的问题
Jun 05 #Python
Pytorch distributed 多卡并行载入模型操作
Jun 05 #Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
You might like
zend framework多模块多布局配置
2011/02/26 PHP
php面向对象与面向过程两种方法给图片添加文字水印
2015/08/26 PHP
PHP中each与list用法分析
2016/01/08 PHP
Yii视图操作之自定义分页实现方法
2016/07/14 PHP
Yii2实现同时搜索多个字段的方法
2016/08/10 PHP
Ajax实现对静态页面的文章访问统计功能示例
2016/10/10 PHP
PHP判断一个变量是否为整数、正整数的方法示例
2019/09/11 PHP
PHP序列化和反序列化深度剖析实例讲解
2020/12/29 PHP
javascript一点特殊用法
2008/05/28 Javascript
ASP.NET jQuery 实例2 (表单中使用回车在TextBox之间向下移动)
2012/01/13 Javascript
js中同步与异步处理的方法和区别总结
2013/12/25 Javascript
js获取表格的行数和列数的方法
2015/10/23 Javascript
非常棒的jQuery图片轮播效果
2016/04/17 Javascript
关于JS Lodop打印插件打印Bootstrap样式错乱问题的解决方案
2016/12/23 Javascript
解决Webpack 热部署检测不到文件变化的问题
2018/02/22 Javascript
layui的表单验证支持ajax判断用户名是否重复的实例
2019/09/06 Javascript
mpvue微信小程序开发之实现一个弹幕评论
2019/11/24 Javascript
解决vue scoped scss 无效的问题
2020/09/04 Javascript
重命名批处理python脚本
2013/04/05 Python
使用py2exe在Windows下将Python程序转为exe文件
2016/03/04 Python
Python使用tablib生成excel文件的简单实现方法
2016/03/16 Python
Python学习pygal绘制线图代码分享
2017/12/09 Python
30秒轻松实现TensorFlow物体检测
2018/03/14 Python
python+opencv实现阈值分割
2018/12/26 Python
使用CodeMirror实现Python3在线编辑器的示例代码
2019/01/14 Python
详解Django+uwsgi+Nginx上线最佳实战
2019/03/14 Python
Django框架获取form表单数据方式总结
2020/04/22 Python
太阳镜仓库,售价20美元或更少:Sunglass Warehouse
2016/09/28 全球购物
草莓网中国:StrawberryNet中国
2020/08/17 全球购物
实习生个人找工作的自我评价
2013/10/30 职场文书
四年大学生活的自我评价范文
2014/02/07 职场文书
书香家庭事迹材料
2014/05/09 职场文书
2016年中秋节晚会领导致辞
2015/11/26 职场文书
文明医院的标语集锦!
2019/07/24 职场文书
详解nodejs内置模块
2021/05/06 NodeJs
FP-growth算法发现频繁项集——构建FP树
2021/06/24 Python