pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中无限元素列表的实现方法
Aug 18 Python
利用Python中的输入和输出功能进行读取和写入的教程
Apr 14 Python
深入理解python中的atexit模块
Mar 07 Python
Django框架登录加上验证码校验实现验证功能示例
May 23 Python
python async with和async for的使用
Jun 20 Python
python交互模式下输入换行/输入多行命令的方法
Jul 02 Python
Python中断多重循环的思路总结
Oct 04 Python
python实现滑雪游戏
Feb 22 Python
python类共享变量操作
Sep 03 Python
用Python 执行cmd命令
Dec 18 Python
python中用ggplot绘制画图实例讲解
Jan 26 Python
python如何将mat文件转为png
Jul 15 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
web server使用php生成web页面的三种方法总结
2013/10/28 PHP
php数据访问之增删改查操作
2016/05/09 PHP
Smarty模板引擎缓存机制详解
2016/05/23 PHP
Javascript SHA-1:Secure Hash Algorithm
2006/12/20 Javascript
多浏览器兼容的获取元素和鼠标的位置的js代码
2009/12/15 Javascript
JQuery操作表格(隔行着色,高亮显示,筛选数据)
2012/02/23 Javascript
13 款最热门的 jQuery 图像 360 度旋转插件推荐
2014/12/09 Javascript
JavaScript中通过提示框跳转页面的方法
2016/02/14 Javascript
js简单判断移动端系统的方法
2016/02/25 Javascript
如何高效率去掉js数组中的重复项
2016/04/12 Javascript
jquery常用的12个小功能
2016/07/22 Javascript
深入理解React中es6创建组件this的方法
2016/08/29 Javascript
原生js实现倒计时功能(多种格式调用)
2017/01/12 Javascript
vue生成token保存在客户端localStorage中的方法
2017/10/25 Javascript
js原生方法被覆盖,从新赋值原生的方法
2018/01/02 Javascript
JS实现提示框跟随鼠标移动
2019/08/27 Javascript
[08:08]DOTA2-DPC中国联赛2月28日Recap集锦
2021/03/11 DOTA
使用Python判断IP地址合法性的方法实例
2014/03/13 Python
python正则表达式match和search用法实例
2015/03/26 Python
python函数局部变量用法实例分析
2015/08/04 Python
详解Python中的静态方法与类成员方法
2017/02/28 Python
Python实现的密码强度检测器示例
2017/08/23 Python
Flask解决跨域的问题示例代码
2018/02/12 Python
pyqt5 lineEdit设置密码隐藏,删除lineEdit已输入的内容等属性方法
2019/06/24 Python
Python3 itchat实现微信定时发送群消息的实例代码
2019/07/12 Python
pytorch 中的重要模块化接口nn.Module的使用
2020/04/02 Python
python 画条形图(柱状图)实例
2020/04/24 Python
Python验证码截取识别代码实例
2020/05/16 Python
python 负数取模运算实例
2020/06/03 Python
Python如何对XML 解析
2020/06/28 Python
记一次Django响应超慢的解决过程
2020/09/17 Python
Django数据模型中on_delete使用详解
2020/11/30 Python
北美三大旅游网站之一:Travelocity
2017/08/12 全球购物
英国时尚优质的女装:Hope Fashion
2018/08/14 全球购物
大学本科生职业生涯规划书范文
2014/09/14 职场文书
微信小程序实现录音Record功能
2021/05/09 Javascript