pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python HTMLParser模块解析html获取url实例
Apr 08 Python
Python聚类算法之凝聚层次聚类实例分析
Nov 20 Python
Python线程池模块ThreadPoolExecutor用法分析
Dec 28 Python
python 获取微信好友列表的方法(微信web)
Feb 21 Python
django如何自己创建一个中间件
Jul 24 Python
Python日志syslog使用原理详解
Feb 18 Python
如何在python中执行另一个py文件
Apr 30 Python
python实现数字炸弹游戏
Jul 17 Python
浅谈Python xlwings 读取Excel文件的正确姿势
Feb 26 Python
用Python简陋模拟n阶魔方
Apr 17 Python
pytorch 如何使用float64训练
May 24 Python
Python数据可视化之用Matplotlib绘制常用图形
Jun 03 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
php递归实现无限分类生成下拉列表的函数
2010/08/08 PHP
php替换超长文本中的特殊字符的函数代码
2012/05/22 PHP
Yii框架批量插入数据扩展类的简单实现方法
2017/05/23 PHP
PHP错误提示It is not safe to rely on the system……的解决方法
2019/03/25 PHP
PHP实现一个限制实例化次数的类示例
2019/09/16 PHP
基于jquery的兼容各种浏览器的iframe自适应高度的脚本
2010/08/13 Javascript
JS的document.all函数使用示例
2013/12/30 Javascript
js获取字符串最后一位方法汇总
2014/11/13 Javascript
JavaScript中关键字 in 的使用方法详解
2016/10/17 Javascript
微信小程序 教程之注册页面
2016/10/17 Javascript
JS实现点击链接切换显示隐藏内容的方法
2017/10/19 Javascript
Webpack实战加载SVG的方法
2017/12/26 Javascript
angularjs 页面自适应高度的方法
2018/01/17 Javascript
微信小程序自定义组件封装及父子间组件传值的方法
2018/08/28 Javascript
vue移动端html5页面根据屏幕适配的四种解决方法
2018/10/19 Javascript
修改layui的后台模板的左侧导航栏可以伸缩的方法
2019/09/10 Javascript
node.js使用stream模块实现自定义流示例
2020/02/13 Javascript
javascript canvas API内容整理
2020/02/16 Javascript
vue中defineProperty和Proxy的区别详解
2020/11/30 Vue.js
[03:57]2016完美“圣”典风云人物:rOtk专访
2016/12/09 DOTA
结合Python的SimpleHTTPServer源码来解析socket通信
2016/06/27 Python
python 计算积分图和haar特征的实例代码
2019/11/20 Python
世界上最悠久的自行车制造商:Ribble Cycles
2017/03/18 全球购物
Schutz鞋官方网站:Schutz Shoes
2017/12/13 全球购物
Stubhub英国:购买体育、演唱会和剧院门票
2018/06/10 全球购物
写一个方法1000的阶乘
2012/11/21 面试题
感恩节红领巾广播稿
2014/02/11 职场文书
初中生期末评语大全
2014/04/24 职场文书
音乐教师求职信
2014/06/28 职场文书
珍惜资源的建议书
2014/08/26 职场文书
2014标准社保办理委托书
2014/10/06 职场文书
北京导游词
2015/02/12 职场文书
创业计划书之健康营养产业
2019/10/15 职场文书
Python爬虫:从m3u8文件里提取小视频的正确操作
2021/05/14 Python
JavaScript流程控制(分支)
2021/12/06 Javascript
Python读取和写入Excel数据
2022/04/20 Python