pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python利用requests库进行接口测试的方法详解
Jul 06 Python
Django如何自定义分页
Sep 25 Python
python3 实现一行输入,空格隔开的示例
Nov 14 Python
python读取Excel表格文件的方法
Sep 02 Python
Flask框架路由和视图用法实例分析
Nov 07 Python
python 变量初始化空列表的例子
Nov 28 Python
基于Python数据分析之pandas统计分析
Mar 03 Python
Python验证码截取识别代码实例
May 16 Python
python属于哪种语言
Aug 16 Python
python把一个字符串切开的实例方法
Sep 27 Python
python对 MySQL 数据库进行增删改查的脚本
Oct 22 Python
分享提高 Python 代码的可读性的技巧
Mar 03 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
php 随机生成10位字符代码
2009/03/26 PHP
应用开发中涉及到的css和php笔记分享
2011/08/02 PHP
php使用function_exists判断函数可用的方法
2014/11/19 PHP
基于jquery的商品展示放大镜
2010/08/07 Javascript
javascript重写alert方法的实例代码
2013/03/29 Javascript
基于KMP算法JavaScript的实现方法分析
2013/05/03 Javascript
JavaScript实现在页面间传值的方法
2015/04/07 Javascript
javascript实现动态导入js与css等静态资源文件的方法
2015/07/25 Javascript
jQuery实现的登录浮动框效果代码
2015/09/26 Javascript
javascript实现获取浏览器版本、浏览器类型
2015/12/02 Javascript
利用jQuery.Validate异步验证用户名是否存在(推荐)
2016/12/09 Javascript
vue登录路由验证的实现
2017/12/13 Javascript
vue element table 表格请求后台排序的方法
2018/09/28 Javascript
vue随机验证码组件的封装实现
2020/02/19 Javascript
vue实现简单加法计算器
2020/10/22 Javascript
在antd中setFieldsValue和defaultVal的用法
2020/10/29 Javascript
Python ZipFile模块详解
2013/11/01 Python
python实现DNS正向查询、反向查询的例子
2014/04/25 Python
Python编程给numpy矩阵添加一列方法示例
2017/12/04 Python
解决python3爬虫无法显示中文的问题
2018/04/12 Python
Python3.7中安装openCV库的方法
2018/07/11 Python
解决python2 绘图title,xlabel,ylabel出现中文乱码的问题
2019/01/29 Python
利用django+wechat-python-sdk 创建微信服务器接入的方法
2019/02/20 Python
Python list列表中删除多个重复元素操作示例
2019/02/27 Python
在pytorch中实现只让指定变量向后传播梯度
2020/02/29 Python
追悼会上的答谢词
2014/01/10 职场文书
管理建议书范文
2014/05/13 职场文书
热门专业求职信
2014/05/24 职场文书
社区创先争优承诺书
2014/08/30 职场文书
公司向个人借款协议书范本
2014/10/09 职场文书
优秀班组申报材料
2014/12/25 职场文书
党员志愿者服务倡议书
2015/04/29 职场文书
2015年秋季运动会广播稿
2015/08/19 职场文书
导游词之江苏溱潼古镇
2019/11/27 职场文书
PO模式在selenium自动化测试框架的优势
2022/03/20 Python
ubuntu安装jupyter并设置远程访问的实现
2022/03/31 Python