pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中unittest模块做UT(单元测试)使用实例
Jun 12 Python
利用Python开发微信支付的注意事项
Aug 19 Python
浅谈Django REST Framework限速
Dec 12 Python
python_opencv用线段画封闭矩形的实例
Dec 05 Python
对python dataframe逻辑取值的方法详解
Jan 30 Python
python使用selenium实现批量文件下载
Mar 11 Python
python实现字符串加密成纯数字
Mar 19 Python
用python建立两个Y轴的XY曲线图方法
Jul 08 Python
pd.DataFrame统计各列数值多少的实例
Dec 05 Python
Pytorch实现各种2d卷积示例
Dec 30 Python
python中return不返回值的问题解析
Jul 22 Python
python3中apply函数和lambda函数的使用详解
Feb 28 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
一个简单计数器的源代码
2006/10/09 PHP
自定义php类(查找/修改)xml文档
2013/03/26 PHP
js实现简单模态窗口,背景灰显
2008/11/14 Javascript
扩展JS Date对象时间格式化功能的小例子
2013/12/02 Javascript
jQuery实现预加载图片的方法
2015/03/17 Javascript
jQuery判断元素上是否绑定了指定事件的方法
2015/03/17 Javascript
jQuery实现的fixedMenu下拉菜单效果代码
2015/08/24 Javascript
jquery实现的横向二级导航效果代码
2015/08/26 Javascript
js实现表单及时验证功能 用户信息立即验证
2016/09/13 Javascript
JS点击某个图标或按钮弹出文件选择框的实现代码
2016/09/27 Javascript
RequireJS 依赖关系的实例(推荐)
2017/01/21 Javascript
Vue利用路由钩子token过期后跳转到登录页的实例
2017/10/26 Javascript
使用vue-route 的 beforeEach 实现导航守卫(路由跳转前验证登录)功能
2018/03/22 Javascript
使用webpack-dev-server处理跨域请求的方法
2018/04/18 Javascript
基于vue写一个全局Message组件的实现
2019/08/15 Javascript
VUE的history模式下除了index外其他路由404报错解决办法
2019/08/21 Javascript
原生js实现拖拽移动与缩放效果
2020/08/24 Javascript
jQuery实现增删改查
2020/12/22 jQuery
node脚手架搭建服务器实现token验证的方法
2021/01/20 Javascript
[01:00:11]DOTA2-DPC中国联赛 正赛 CDEC vs DLG BO3 第一场 2月7日
2021/03/11 DOTA
[01:32:22]DOTA2-DPC中国联赛 正赛 Ehome vs VG BO3 第一场 2月5日
2021/03/11 DOTA
Python实现快速排序和插入排序算法及自定义排序的示例
2016/02/16 Python
python制作企业邮箱的爆破脚本
2016/10/05 Python
Django管理员账号和密码忘记的完美解决方法
2018/12/06 Python
python 同时读取多个文件的例子
2019/07/16 Python
PyCharm无法识别PyQt5的2种解决方法,ModuleNotFoundError: No module named 'pyqt5'
2020/02/17 Python
CSS3中Transform动画属性用法详解
2016/07/04 HTML / CSS
使用phonegap进行提示操作的具体方法
2017/03/30 HTML / CSS
Oracle里面常用的数据字典有哪些
2014/02/14 面试题
最新远光软件笔试题面试题内容
2013/11/08 面试题
上班迟到检讨书
2014/01/10 职场文书
小学庆六一活动总结
2014/08/28 职场文书
2016年安全月活动总结
2016/04/06 职场文书
Redis遍历所有key的两个命令(KEYS 和 SCAN)
2021/04/12 Redis
MIME类型中application/xml与text/xml的区别介绍
2022/01/18 HTML / CSS
科学家研发出新型速效酶,可在 24 小时内降解塑料制品
2022/04/29 数码科技