pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python随机生成数模块random使用实例
Apr 13 Python
python中随机函数random用法实例
Apr 30 Python
利用python实现简单的循环购物车功能示例代码
Jul 05 Python
Python+PyQT5的子线程更新UI界面的实例
Jun 14 Python
python内存管理机制原理详解
Aug 12 Python
python的scipy实现插值的示例代码
Nov 12 Python
python cv2截取不规则区域图片实例
Dec 21 Python
Python通过正则库爬取淘宝商品信息代码实例
Mar 02 Python
django 连接数据库出现1045错误的解决方式
May 14 Python
利用Python实现Json序列化库的方法步骤
Sep 09 Python
Python 删除List元素的三种方法remove、pop、del
Nov 16 Python
Opencv python 图片生成视频的方法示例
Nov 18 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
下载文件的点击数回填
2006/10/09 PHP
php 采集书并合成txt格式的实现代码
2009/03/01 PHP
优化PHP程序的方法小结
2012/02/23 PHP
使用php get_headers 判断URL是否有效的解决办法
2013/04/27 PHP
非常重要的php正则表达式详解
2016/01/04 PHP
yii2中使用Active Record模式的方法
2016/01/09 PHP
php 实现一个字符串加密解密的函数实例代码
2016/11/01 PHP
一个符号插入器 中用到的js代码
2007/09/04 Javascript
Javascript的getYear、getFullYear、getUTCFullYear异同分享
2011/11/30 Javascript
JQuery中如何传递参数如click(),change()等具体实现
2013/04/28 Javascript
JavaScript语言核心数据类型和变量使用介绍
2013/08/23 Javascript
Jqgrid表格随窗口大小改变而改变的简单实例
2013/12/28 Javascript
javascript判断是手机还是电脑访问网页的简单实例分享
2014/06/03 Javascript
jQuery中extend函数详解
2015/07/13 Javascript
js事件驱动机制 浏览器兼容处理方法
2016/07/23 Javascript
使用BootStrap建立响应式网页——通栏轮播图(carousel)
2016/12/21 Javascript
node.js的事件机制
2017/02/08 Javascript
jquery实现数字输入框
2017/02/22 Javascript
详细讲解如何创建, 发布自己的 Vue UI 组件库
2019/05/29 Javascript
jQuery 筛选器简单操作示例
2019/10/02 jQuery
python切换hosts文件代码示例
2013/12/31 Python
Python中的深拷贝和浅拷贝详解
2015/06/03 Python
Python卸载模块的方法汇总
2016/06/07 Python
scrapy-redis的安装部署步骤讲解
2019/02/27 Python
python经典趣味24点游戏程序设计
2019/07/26 Python
python 实现矩阵按对角线打印
2019/11/29 Python
Big Green Smile法国:领先的英国有机和天然产品在线商店
2021/01/02 全球购物
长曲棍球装备:Lacrosse Monkey
2020/12/02 全球购物
介绍一下mysql的日期和时间函数
2013/03/28 面试题
理工科学生的自我评价
2013/12/15 职场文书
医院保洁服务方案
2014/06/11 职场文书
商场消防安全责任书
2014/07/29 职场文书
2014年度思想工作总结
2014/11/27 职场文书
2014年银行年终工作总结
2014/12/19 职场文书
实习工作表现评语
2014/12/31 职场文书
委托书格式要求
2015/01/28 职场文书