pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅谈python函数之作用域(python3.5)
Oct 27 Python
对python opencv 添加文字 cv2.putText 的各参数介绍
Dec 05 Python
Python数据可视化库seaborn的使用总结
Jan 15 Python
Python实现二叉树的常见遍历操作总结【7种方法】
Mar 06 Python
使用Python检测文章抄袭及去重算法原理解析
Jun 14 Python
对Python函数设计规范详解
Jul 19 Python
Python (Win)readline和tab补全的安装方法
Aug 27 Python
python实现引用其他路径包里面的模块
Mar 09 Python
python logging.info在终端没输出的解决
May 12 Python
python支持多继承吗
Jun 19 Python
Keras 中Leaky ReLU等高级激活函数的用法
Jul 05 Python
Python使用paramiko连接远程服务器执行Shell命令的实现
Mar 04 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
跟我学小偷程序之成功偷取首页(第三天)
2006/10/09 PHP
ThinkPHP中实例Model方法的区别说明
2010/08/21 PHP
PHP PDOStatement::getAttribute讲解
2019/02/01 PHP
PHP实现通过文本文件统计页面访问量功能示例
2019/02/13 PHP
编写跨浏览器的javascript代码必备[js多浏览器兼容写法]
2008/10/29 Javascript
浅析jQuery的链式调用之each函数
2010/12/03 Javascript
TimergliderJS 一个基于jQuery的时间轴插件
2011/12/07 Javascript
Dom 学习总结以及实例的使用介绍
2013/04/24 Javascript
简单易用的倒计时js代码
2014/08/04 Javascript
JavaScript获取页面上被选中文字的方法技巧
2015/03/13 Javascript
详解页面滚动值scrollTop在FireFox与Chrome浏览器间的兼容问题
2015/12/03 Javascript
JS代码实现table数据分页效果
2016/05/26 Javascript
react-router中的属性详解
2017/06/01 Javascript
jQuery制作全屏宽度固定高度轮播图(实例讲解)
2017/07/08 jQuery
jQuery实现对网页节点的增删改查功能示例
2017/09/18 jQuery
基于vue打包后字体和图片资源失效问题的解决方法
2018/03/06 Javascript
Vue Cli3 创建项目的方法步骤
2018/10/15 Javascript
Vue 刷新当前路由的实现代码
2019/09/26 Javascript
python 多进程通信模块的简单实现
2014/02/20 Python
python队列queue模块详解
2018/04/27 Python
Python实现截取PDF文件中的几页代码实例
2019/03/11 Python
python调用pyaudio使用麦克风录制wav声音文件的教程
2019/06/26 Python
在python中将list分段并保存为array类型的方法
2019/07/15 Python
Python文件时间操作步骤代码详解
2020/04/13 Python
使用python实现微信小程序自动签到功能
2020/04/27 Python
python对 MySQL 数据库进行增删改查的脚本
2020/10/22 Python
Numpy数组的广播机制的实现
2020/11/03 Python
国际鲜花速递专家:Floraqueen
2016/11/24 全球购物
英国No.1体育用品零售商:SportsDirect.com
2019/10/16 全球购物
考试退步检讨书
2014/01/15 职场文书
残疾人小组计划书
2014/04/27 职场文书
2014年党的群众路线活动个人整改措施
2014/10/28 职场文书
2015政治思想表现评语
2015/03/25 职场文书
幼儿园卫生保健制度
2015/08/05 职场文书
2016年禁毒宣传活动总结
2016/04/05 职场文书
Java如何实现树的同构?
2021/06/22 Java/Android