pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python益智游戏计算汉诺塔问题示例
Mar 05 Python
解读Python编程中的命名空间与作用域
Oct 16 Python
Python编程中装饰器的使用示例解析
Jun 20 Python
Python实现带百分比的进度条
Jun 28 Python
Python数据分析之真实IP请求Pandas详解
Nov 18 Python
详谈Python2.6和Python3.0中对除法操作的异同
Apr 28 Python
Python实现的插入排序算法原理与用法实例分析
Nov 22 Python
用Python分析3天破10亿的《我不是药神》到底神在哪?
Jul 12 Python
Flask实现图片的上传、下载及展示示例代码
Aug 03 Python
python 线性回归分析模型检验标准--拟合优度详解
Feb 24 Python
Python selenium页面加载慢超时的解决方案
Mar 18 Python
python zip,lambda,map函数代码实例
Apr 04 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
我的论坛源代码(四)
2006/10/09 PHP
PHP实现基于回溯法求解迷宫问题的方法详解
2017/08/17 PHP
关于js new Date() 出现NaN 的分析
2012/10/23 Javascript
JavaScript获取onclick、onchange等事件值的代码
2013/07/22 Javascript
jQuery Mobile的loading对话框显示/隐藏方法分享
2013/11/26 Javascript
基于jquery实现的可编辑下拉框实现代码
2014/08/02 Javascript
我用的一些Node.js开发工具、开发包、框架等总结
2014/09/25 Javascript
浅谈JavaScript数据类型
2015/03/03 Javascript
JavaScript在浏览器标题栏上显示当前日期和时间的方法
2015/03/19 Javascript
通过js获取上传的图片信息(临时保存路径,名称,大小)然后通过ajax传递给后端的方法
2015/10/01 Javascript
举例讲解JavaScript中将数组元素转换为字符串的方法
2015/10/25 Javascript
第九篇Bootstrap导航菜单创建步骤详解
2016/06/21 Javascript
vue基于Vue2.0和高德地图的地图组件实例
2017/04/28 Javascript
利用forever和pm2部署node.js项目过程
2017/05/10 Javascript
vue 运用mock数据的示例代码
2017/11/07 Javascript
react+ant design实现Table的增、删、改的示例代码
2018/12/27 Javascript
jQuery动态生成的元素绑定事件操作实例分析
2019/05/04 jQuery
jquery.pager.js实现分页效果
2019/07/29 jQuery
浅谈layui分页控件field参数接收对象的问题
2019/09/20 Javascript
ant design vue datepicker日期选择器中文化操作
2020/10/28 Javascript
JavaScript Dom实现轮播图原理和实例
2021/02/19 Javascript
[01:07:47]Secret vs Optic Supermajor 胜者组 BO3 第一场 6.4
2018/06/05 DOTA
centos系统升级python 2.7.3
2014/07/03 Python
在Python的Flask框架中实现全文搜索功能
2015/04/20 Python
简单的Apache+FastCGI+Django配置指南
2015/07/22 Python
CentOS安装pillow报错的解决方法
2016/01/27 Python
Python处理CSV与List的转换方法
2018/04/19 Python
python实现日志按天分割
2019/07/22 Python
Python使用扩展库pywin32实现批量文档打印实例
2020/04/09 Python
python logging通过json文件配置的步骤
2020/04/27 Python
纯HTML5+CSS3制作图片旋转
2016/01/12 HTML / CSS
房地产销售计划书
2014/01/10 职场文书
承诺书的格式范文
2014/03/28 职场文书
人力资源管理毕业求职信
2014/08/05 职场文书
军训阅兵新闻稿
2015/07/17 职场文书
竞聘书的秘诀
2019/04/02 职场文书