pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python3模拟百度登录并实现百度贴吧签到示例分享(百度贴吧自动签到)
Feb 24 Python
详解Python编程中基本的数学计算使用
Feb 04 Python
python cx_Oracle模块的安装和使用详细介绍
Feb 13 Python
python list元素为tuple时的排序方法
Apr 18 Python
如何实现删除numpy.array中的行或列
May 08 Python
利用python循环创建多个文件的方法
Oct 25 Python
使用Python-OpenCV向图片添加噪声的实现(高斯噪声、椒盐噪声)
May 28 Python
python matplotlib拟合直线的实现
Nov 19 Python
python多进程并发demo实例解析
Dec 13 Python
Python实现进度条和时间预估的示例代码
Jun 02 Python
解决keras模型保存h5文件提示无此目录问题
Jul 01 Python
通俗易懂了解Python装饰器原理
Sep 17 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
PHP 基本语法格式
2009/12/15 PHP
使用php测试硬盘写入速度示例
2014/01/27 PHP
thinkphp实现图片上传功能分享
2014/03/04 PHP
PHP框架性能测试报告
2016/05/08 PHP
php计数排序算法的实现代码(附四个实例代码)
2020/03/31 PHP
jquery 新浪网易的评论块制作
2010/07/01 Javascript
jQery使网页在显示器上居中显示适用于任何分辨率
2014/06/09 Javascript
Nodejs极简入门教程(一):模块机制
2014/10/25 NodeJs
EasyUI,点击开启编辑框,并且编辑框获得焦点的方法
2015/03/01 Javascript
通过jquery-ui中的sortable来实现拖拽排序的简单实例
2016/05/24 Javascript
修改Jquery Dialog 位置的实现方法
2016/08/26 Javascript
jQuery中clone()函数实现表单中增加和减少输入项
2017/05/13 jQuery
react-native-tab-navigator组件的基本使用示例代码
2017/09/07 Javascript
深入浅析Vue中的Prop
2018/06/10 Javascript
微信小程序实现复选框效果
2018/12/28 Javascript
微信小程序实现的一键拨号功能示例
2019/04/24 Javascript
JavaScript对象原型链原理解析
2020/01/22 Javascript
一行JavaScript代码如何实现瀑布流布局
2020/12/11 Javascript
详解python使用Nginx和uWSGI来运行Python应用
2018/01/09 Python
Python 一句话生成字母表的方法
2019/01/02 Python
python常用库之NumPy和sklearn入门
2019/07/11 Python
Python递归求出列表(包括列表中的子列表)的最大值实例
2020/02/27 Python
Python3 mmap内存映射文件示例解析
2020/03/23 Python
使用Keras预训练模型ResNet50进行图像分类方式
2020/05/23 Python
使用Python实现微信拍一拍功能的思路代码
2020/07/09 Python
销售行政专员岗位职责
2014/06/10 职场文书
车贷收入证明范本
2014/09/14 职场文书
国庆节标语大全
2014/10/08 职场文书
施工单位工程部经理岗位职责
2015/04/09 职场文书
党员干部廉洁自律承诺书
2015/04/28 职场文书
现货白银电话营销话术
2015/05/29 职场文书
建国大业观后感
2015/06/01 职场文书
HTML基础-标签分类(闭合标签,空标签,块级元素,行内元素,行级块元素,可替换元素)
2021/03/31 HTML / CSS
MATLAB 如何求取离散点的曲率最大值
2021/04/16 Python
Python超简单容易上手的画图工具库推荐
2021/05/10 Python
解决mysql的int型主键自增问题
2021/07/15 MySQL