pytorch中的自定义反向传播,求导实例


Posted in Python onJanuary 06, 2020

pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。

那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`

import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
  x_abs = np.abs(x)
  if x_abs < 1 and x_abs >= 0:
    y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
  elif x_abs > 1 and x_abs < 2:
    y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
  else:
    y = 0
  return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
  # data_in = data_in.detach().numpy()
  self.grad = np.zeros(data_in.shape,dtype=np.float32)
  obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
  data_tmp = data_in.copy()
  data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
  data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
  print(data_tmp.shape)
  for axis0 in range(obj_shape[0]):
    f_0 = float(axis0) / scale - np.floor(axis0 / scale)
    int_0 = int(axis0 / scale) + 2
    axis0_weight = np.array(
      [[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
    for axis1 in range(obj_shape[1]):
      f_1 = float(axis1) / scale - np.floor(axis1 / scale)
      int_1 = int(axis1 / scale) + 2
      axis1_weight = np.array(
        [[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
      nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
      grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
      for i in range(4):
        for j in range(4):
          nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
          for ii in range(data_in.shape[2]):
            self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
      tmp = np.matmul(axis0_weight, nbr_pixel)
      data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
      # img = np.transpose(img[0, :, :, :], [1, 2, 0])
  return data_obj

def forward(self,input):
  print(type(input))
  input_ = input.detach().numpy()
  output = self.bicubic_interpolate(input_)
  # return input.new(output)
  return torch.Tensor(output)

def backward(self,grad_output):
  print(self.grad.shape,grad_output.shape)
  grad_output.detach().numpy()
  grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
  for i in range(self.grad.shape[0]):
    for j in range(self.grad.shape[1]):
      grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
  grad_input = grad_output_tmp*self.grad
  print(type(grad_input))
  # return grad_output.new(grad_input)
  return torch.Tensor(grad_input)

def bicubic(input):
return Bicubic()(input)

def main():
	hr = Image.open('./baboon/baboon_hr.png').convert('L')
	hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
	hr.requires_grad = True
	lr = bicubic(hr)
	print(lr.is_leaf)
	loss=torch.mean(lr)
	loss.backward()
if __name__ =='__main__':
	main()

要想实现自动求导,必须同时实现forward(),backward()两个函数。

1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。

2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数

以上这篇pytorch中的自定义反向传播,求导实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python每隔N秒运行指定函数的方法
Mar 16 Python
Collatz 序列、逗号代码、字符图网格实例
Jun 22 Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 Python
对python多线程与global变量详解
Nov 09 Python
Python面向对象程序设计OOP入门教程【类,实例,继承,重载等】
Jan 05 Python
Python3中函数参数传递方式实例详解
May 05 Python
python如何实现代码检查
Jun 28 Python
使用PYTHON解析Wireshark的PCAP文件方法
Jul 23 Python
PyTorch预训练的实现
Sep 18 Python
你还在@微信官方?聊聊Python生成你想要的微信头像
Sep 25 Python
windows python3安装Jupyter Notebooks教程
Apr 13 Python
Django实现聊天机器人
May 31 Python
PyTorch中 tensor.detach() 和 tensor.data 的区别详解
Jan 06 #Python
6行Python代码实现进度条效果(Progress、tqdm、alive-progress​​​​​​​和PySimpleGUI库)
Jan 06 #Python
基于python+selenium的二次封装的实现
Jan 06 #Python
Python使用Tkinter实现滚动抽奖器效果
Jan 06 #Python
Python使用Tkinter实现转盘抽奖器的步骤详解
Jan 06 #Python
pytorch 实现在预训练模型的 input上增减通道
Jan 06 #Python
Python 将json序列化后的字符串转换成字典(推荐)
Jan 06 #Python
You might like
星际争霸 Starcraft 编年史
2020/03/14 星际争霸
使用淘宝IP库获取用户ip地理位置
2013/10/27 PHP
php常量详细解析
2015/10/27 PHP
详解在YII2框架中使用UEditor编辑器发布文章
2018/11/02 PHP
PHP利用DWZ.CN服务生成短网址
2019/08/11 PHP
PHP接入支付宝接口失效流程详解
2020/11/10 PHP
Yii 实现数据加密和解密
2021/03/09 PHP
javascript showModalDialog模态对话框使用说明
2009/12/31 Javascript
js下用gb2312编码解码实现方法
2009/12/31 Javascript
Array.prototype 的泛型应用分析
2010/04/30 Javascript
仅Firefox中链接A无法实现模拟点击以触发其默认行为
2011/07/31 Javascript
js实现带搜索功能的下拉框实时搜索实时匹配
2013/11/05 Javascript
深入理解Javascript动态方法调用与参数修改的问题
2013/12/10 Javascript
input:checkbox多选框实现单选效果跟radio一样
2014/06/16 Javascript
jQuery简易时光轴实现方法示例
2017/03/13 Javascript
微信小程序MUI侧滑导航菜单示例(Popup弹出式,左侧滑动,右侧不动)
2019/01/23 Javascript
微信小程序点击图片实现长按预览、保存、识别带参数二维码、转发等功能
2019/07/20 Javascript
Moment.js实现多个同时倒计时
2019/08/26 Javascript
vue根据条件不同显示不同按钮的操作
2020/08/04 Javascript
pycharm 使用心得(五)断点调试
2014/06/06 Python
浅析python实现scrapy定时执行爬虫
2018/03/04 Python
Django基础知识 web框架的本质详解
2019/07/18 Python
python代码实现逻辑回归logistic原理
2019/08/07 Python
python suds访问webservice服务实现
2020/06/26 Python
Python3使用 GitLab API 进行批量合并分支
2020/10/15 Python
中英双版中文教师求职信
2013/10/27 职场文书
投标邀请书范文
2014/01/31 职场文书
调解员先进事迹材料
2014/02/07 职场文书
班级团队活动方案
2014/08/14 职场文书
学校领导四风问题整改措施思想汇报
2014/10/09 职场文书
党的群众路线教育实践活动学习笔记范文
2014/11/06 职场文书
长城英文导游词
2015/01/30 职场文书
react中的DOM操作实现
2021/06/30 Javascript
试了下Golang实现try catch的方法
2021/07/01 Golang
详解JS数组方法
2021/11/20 Javascript
MySQL如何修改字段类型和字段长度
2022/06/10 MySQL