pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Linux下Python获取IP地址的代码
Nov 30 Python
在Django中限制已登录用户的访问的方法
Jul 23 Python
Python爬取当当、京东、亚马逊图书信息代码实例
Dec 09 Python
浅谈python中对于json写入txt文件的编码问题
Jun 07 Python
Python 3.8中实现functools.cached_property功能
May 29 Python
利用python-docx模块写批量生日邀请函
Aug 26 Python
Python二次规划和线性规划使用实例
Dec 09 Python
Django权限设置及验证方式
May 13 Python
selenium判断元素是否存在的两种方法小结
Dec 07 Python
pytorch __init__、forward与__call__的用法小结
Feb 27 Python
如何使用Tkinter进行窗口的管理与设置
Jun 30 Python
Django框架中表单的用法
Jun 10 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
php下将XML转换为数组
2010/01/01 PHP
php数据类型判断函数有哪些
2013/09/23 PHP
PHP中ini_set与ini_get用法实例
2014/11/04 PHP
8个PHP数组面试题
2015/06/23 PHP
Laravel源码解析之路由的使用和示例详解
2018/09/27 PHP
PhpStorm2020 + phpstudyV8 +XDebug的教程详解
2020/09/17 PHP
基于jquery的复制网页内容到WORD的实现代码
2011/02/16 Javascript
jQuery Tools tab(幻灯片)
2012/07/14 Javascript
jQuery之排序组件的深入解析
2013/06/19 Javascript
javascript陷阱 一不小心你就中招了(字符运算)
2013/11/10 Javascript
js特殊字符过滤的示例代码
2014/03/05 Javascript
js获取元素外链样式的方法
2015/01/27 Javascript
详解JS中Array对象扩展与String对象扩展
2016/01/07 Javascript
深入理解jquery中的事件与动画
2016/05/24 Javascript
JavaScript中Array对象用法实例总结
2016/11/29 Javascript
JavaScript Canvas绘制圆形时钟效果
2020/08/20 Javascript
JS实现的简单表单验证功能示例
2017/10/13 Javascript
微信小程序支付PHP代码
2018/08/23 Javascript
微信小程序扫描二维码获取信息实例详解
2019/05/07 Javascript
Vue2.0 $set()的正确使用详解
2020/07/28 Javascript
新手该如何学python怎么学好python?
2008/10/07 Python
Python深入学习之闭包
2014/08/31 Python
Python Web框架Tornado运行和部署
2020/10/19 Python
Python自动化开发学习之三级菜单制作
2017/07/14 Python
Tensorflow模型实现预测或识别单张图片
2019/07/19 Python
python实现通过flask和前端进行数据收发
2019/08/22 Python
Python3中的f-Strings增强版字符串格式化方法
2020/03/04 Python
最新大学生自我评价
2013/09/24 职场文书
菜篮子工程实施方案
2014/03/08 职场文书
旅游节目策划方案
2014/05/26 职场文书
《中国梦我的梦》中学生演讲稿
2014/08/20 职场文书
医院见习报告范文
2014/11/03 职场文书
2015年度优秀员工自荐书
2015/03/06 职场文书
2015年护理工作总结范文
2015/04/03 职场文书
2015年大学生村官工作总结
2015/04/21 职场文书
幼儿园安全教育随笔
2015/08/14 职场文书