pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python学习之Anaconda的使用与配置方法
Jan 04 Python
Python3爬取英雄联盟英雄皮肤大图实例代码
Nov 14 Python
python正则-re的用法详解
Jul 28 Python
Django模型修改及数据迁移实现解析
Aug 01 Python
详细整理python 字符串(str)与列表(list)以及数组(array)之间的转换方法
Aug 30 Python
Python @property使用方法解析
Sep 17 Python
win10子系统python开发环境准备及kenlm和nltk的使用教程
Oct 14 Python
关于Python中定制类的比较运算实例
Dec 19 Python
python文件操作seek()偏移量,读取指正到指定位置操作
Jul 05 Python
Python 3.9的到来到底是意味着什么
Oct 14 Python
基于Python制作一副扑克牌过程详解
Oct 19 Python
Python如何使用神经网络进行简单文本分类
Feb 25 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
PHP的fsockopen、pfsockopen函数被主机商禁用的解决办法
2014/07/08 PHP
php使用pdo连接并查询sql数据库的方法
2014/12/24 PHP
PHP中使用Memache作为进程锁的操作类分享
2015/03/30 PHP
基于PHP实现的多元线性回归模拟曲线算法
2018/01/30 PHP
PHP封装curl的调用接口及常用函数详解
2018/05/31 PHP
javascript Split方法,indexOf方法、lastIndexOf 方法和substring 方法
2009/03/21 Javascript
jquery实现图片滚动效果的简单实例
2013/11/23 Javascript
JQuery实现鼠标移动到图片上显示边框效果
2014/01/09 Javascript
jquery禁止输入数字以外的字符的示例(纯数字验证码)
2014/04/10 Javascript
原生js仿jq判断当前浏览器是否为ie,精确到ie6~8
2014/08/30 Javascript
JavaScript多并发问题如何处理
2015/10/28 Javascript
每日十条JavaScript经验技巧(二)
2016/06/23 Javascript
JS遍历对象属性的方法示例
2017/01/10 Javascript
微信小程序 本地数据读取实例
2017/04/27 Javascript
微信小程序学习笔记之获取位置信息操作图文详解
2019/03/29 Javascript
Vue打包后访问静态资源路径问题
2019/11/08 Javascript
Python中优化NumPy包使用性能的教程
2015/04/23 Python
在 Python 应用中使用 MongoDB的方法
2017/01/05 Python
Linux RedHat下安装Python2.7开发环境
2017/05/20 Python
Python实现的十进制小数与二进制小数相互转换功能
2017/10/12 Python
python try except 捕获所有异常的实例
2018/10/18 Python
Python+Selenium使用Page Object实现页面自动化测试
2019/07/14 Python
Python 多线程,threading模块,创建子线程的两种方式示例
2019/09/29 Python
matplotlib.pyplot画图并导出保存的实例
2019/12/07 Python
python socket通信编程实现文件上传代码实例
2019/12/14 Python
keras获得某一层或者某层权重的输出实例
2020/01/24 Python
python中tab键是什么意思
2020/06/18 Python
pycharm配置python 设置pip安装源为豆瓣源
2021/02/05 Python
Python的Tqdm模块实现进度条配置
2021/02/24 Python
Kangol帽子官网:坎戈尔袋鼠
2018/09/26 全球购物
日本最大化妆品和美容产品的综合口碑网站:cosme shopping
2019/08/28 全球购物
迟到检讨书300字
2014/02/14 职场文书
入党政审材料范文
2014/12/24 职场文书
钱学森观后感
2015/06/04 职场文书
幼儿园开学家长寄语(2016秋季)
2015/12/03 职场文书
教你使用Pandas直接核算Excel中快递费用
2021/05/12 Python