pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的hashlib和base64加密模块使用实例
Sep 02 Python
python实现批量下载新浪博客的方法
Jun 15 Python
python解决pandas处理缺失值为空字符串的问题
Apr 08 Python
Python中的函数作用域
May 07 Python
Python中if elif else及缩进的使用简述
May 31 Python
对numpy中shape的深入理解
Jun 15 Python
详解python selenium 爬取网易云音乐歌单名
Mar 28 Python
python中的&amp;&amp;及||的实现示例
Aug 07 Python
python银行系统实现源码
Oct 25 Python
Python log模块logging记录打印用法解析
Jan 20 Python
Python 微信公众号文章爬取的示例代码
Nov 30 Python
Python第三方库安装缓慢的解决方法
Feb 06 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
计算php页面运行时间的函数介绍
2013/07/01 PHP
学习php设计模式 php实现原型模式(prototype)
2015/12/07 PHP
jQuery1.6 使用方法一
2011/11/23 Javascript
AngularJS基础学习笔记之简单介绍
2015/05/10 Javascript
jQuery 1.9.1源码分析系列(十五)之动画处理
2015/12/03 Javascript
JavaScript 闭包详细介绍
2016/09/28 Javascript
微信小程序开发之视频播放器 Video 弹幕 弹幕颜色自定义实例
2016/12/08 Javascript
Vue-Router实现页面正在加载特效方法示例
2017/02/12 Javascript
数组Array的一些方法(总结)
2017/02/17 Javascript
vue监听滚动事件实现滚动监听
2017/04/11 Javascript
Bootstrap提示框效果的实例代码
2017/07/12 Javascript
js获取文件里面的所有文件名(实例)
2017/10/17 Javascript
JavaScript实现单例模式实例分享
2017/12/22 Javascript
node.js+express+mySQL+ejs+bootstrop实现网站登录注册功能
2018/01/12 Javascript
webstorm中配置nodejs环境及npm的实例
2018/05/15 NodeJs
详解Angular6 热加载配置方案
2018/08/18 Javascript
JavaScript this关键字的深入详解
2021/01/14 Javascript
js实现验证码干扰(动态)
2021/02/23 Javascript
[39:08]完美世界DOTA2联赛PWL S3 LBZS vs CPG 第一场 12.12
2020/12/16 DOTA
Python批量修改文件后缀的方法
2014/01/26 Python
python编写的最短路径算法
2015/03/25 Python
使用python实现省市三级菜单效果
2016/01/20 Python
解决Django的request.POST获取不到内容的问题
2018/05/28 Python
selenium+python实现自动化登录的方法
2018/09/04 Python
python获取地震信息 微信实时推送
2019/06/18 Python
keras自动编码器实现系列之卷积自动编码器操作
2020/07/03 Python
python 如何对logging日志封装
2020/12/02 Python
canvas 如何绘制线段的实现方法
2018/07/12 HTML / CSS
苹果美国官方商城:Apple美国
2016/08/24 全球购物
美国韩国化妆品和护肤品购物网站:Beautytap
2018/07/29 全球购物
Linux开机引导的步骤是什么
2014/02/26 面试题
寄语是什么意思
2014/04/10 职场文书
对祖国的寄语大全
2014/04/11 职场文书
JavaScript原始值与包装对象的详细介绍
2021/05/11 Javascript
MySQL into_Mysql中replace与replace into用法案例详解
2021/09/14 MySQL
Oracle 死锁的检测查询及处理
2021/09/25 Oracle