pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 3实战爬虫之爬取京东图书的图片详解
Oct 09 Python
Python中max函数用于二维列表的实例
Apr 03 Python
python 字符串和整数的转换方法
Jun 25 Python
Django+JS 实现点击头像即可更改头像的方法示例
Dec 26 Python
Python3数字求和的实例
Feb 19 Python
numpy下的flatten()函数用法详解
May 27 Python
PyCharm+Qt Designer+PyUIC安装配置教程详解
Jun 13 Python
numpy数组广播的机制
Jul 12 Python
Python调用.NET库的方法步骤
Dec 27 Python
深入了解如何基于Python读写Kafka
Dec 31 Python
Python3自定义json逐层解析器代码
May 11 Python
python3字符串输出常见面试题总结
Dec 01 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
用PHP将数据导入到Foxmail
2006/10/09 PHP
php mysql 判断update之后是否更新了的方法
2012/01/10 PHP
PHP中的Memcache详解
2014/04/05 PHP
php使用cookie保存登录用户名的方法
2015/01/26 PHP
php中访问修饰符的知识点总结
2019/01/27 PHP
Yii框架视图、视图布局、视图数据块操作示例
2019/10/14 PHP
jQuery 表单验证插件formValidation实现个性化错误提示
2009/06/23 Javascript
基于jQuery的的一个隔行变色,鼠标移动变色的小插件
2010/07/06 Javascript
深入理解JavaScript系列(3) 全面解析Module模式
2012/01/15 Javascript
JS 去前后空格大全(IE9亲测)
2013/07/15 Javascript
鼠标移到div,浮层显示明细,弹出层与div的上边距左边距重合(示例代码)
2013/12/14 Javascript
jquery实现的简单二级菜单效果代码
2015/09/22 Javascript
javascript实现Email邮件显示与删除功能
2015/11/21 Javascript
js组件SlotMachine实现图片切换效果制作抽奖系统
2016/04/17 Javascript
以WordPress为例讲解jQuery美化页面Title的方法
2016/05/23 Javascript
Angular页面间切换及传值的4种方法
2016/11/04 Javascript
Web前端框架bootstrap实战【第一次接触使用】
2016/12/28 Javascript
JavaScript fetch接口案例解析
2018/08/30 Javascript
vue单页缓存存在的问题及解决方案(小结)
2018/09/25 Javascript
浅谈angular2子组件的事件传递(任意组件事件传递)
2018/09/30 Javascript
浅析vue-router实现原理及两种模式
2020/02/11 Javascript
利用原生JS实现欢乐水果机小游戏
2020/04/23 Javascript
vue keep-alive实现多组件嵌套中个别组件存活不销毁的操作
2020/10/30 Javascript
[38:27]完美世界DOTA2联赛PWL S2 Forest vs FTD.C 第二场 11.26
2020/11/30 DOTA
Python使用Flask框架同时上传多个文件的方法
2015/03/21 Python
详解pandas删除缺失数据(pd.dropna()方法)
2019/06/25 Python
Python实现Mysql数据统计及numpy统计函数
2019/07/15 Python
使用Python操作ArangoDB的方法步骤
2020/02/02 Python
PHP面试题及答案一
2012/06/18 面试题
少先队入队活动方案
2014/02/08 职场文书
驻村工作先进事迹
2014/08/14 职场文书
旅行社优秀创业计划书
2014/08/16 职场文书
商场圣诞节活动总结
2015/05/06 职场文书
创业计划书之美甲店
2019/09/20 职场文书
PHP判断是否是json字符串
2021/04/01 PHP
Go语言基础知识点介绍
2021/07/04 Golang