pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析文件示例
Jan 23 Python
Python 字典dict使用介绍
Nov 30 Python
Python列表list内建函数用法实例分析【insert、remove、index、pop等】
Jul 24 Python
详解 Python中LEGB和闭包及装饰器
Aug 03 Python
Python实现正弦信号的时域波形和频谱图示例【基于matplotlib】
May 04 Python
详解python列表生成式和列表生成式器区别
Mar 27 Python
python pycharm的安装及其使用
Oct 11 Python
Python flask框架实现浏览器点击自定义跳转页面
Jun 04 Python
使用pytorch实现论文中的unet网络
Jun 24 Python
Python定时任务APScheduler安装及使用解析
Aug 07 Python
Python list和str互转的实现示例
Nov 16 Python
详解matplotlib绘图样式(style)初探
Feb 03 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
Eclipse中php插件安装及Xdebug配置的使用详解
2013/04/25 PHP
ThinkPHP框架实现的MySQL数据库备份功能示例
2018/05/24 PHP
PHP实现的多进程控制demo示例
2019/07/22 PHP
laravel 解决groupBy时出现的错误 isn't in Group By问题
2019/10/17 PHP
PHP实现随机发放扑克牌
2020/04/21 PHP
jquery ajax 检测用户注册时用户名是否存在
2009/11/03 Javascript
jQuery前台数据获取实现代码
2011/03/16 Javascript
腾讯UED 漂亮的提示信息效果代码
2011/09/12 Javascript
Javascript绝句欣赏 一些经典的js代码
2012/02/22 Javascript
jQuery实现的原图对比窗帘效果
2014/06/15 Javascript
javascript中加var和不加var的区别 你真的懂吗
2016/01/06 Javascript
JS获取url参数、主域名的方法实例分析
2016/08/03 Javascript
Angularjs中的页面访问权限怎么设置
2016/11/11 Javascript
详解npm 配置项registry修改为淘宝镜像
2018/09/07 Javascript
AngularJS 监听变量变化的实现方法
2018/10/09 Javascript
elementUI中Table表格问题的解决方法
2018/12/04 Javascript
Node.js中console.log()输出彩色字体的方法示例
2019/12/01 Javascript
nodejs使用Sequelize框架操作数据库的实现
2020/10/21 NodeJs
深入解析Python中的WSGI接口
2015/05/11 Python
Python利用turtle库绘制彩虹代码示例
2017/12/20 Python
Python数据预处理之数据规范化(归一化)示例
2019/01/08 Python
Python 定义只读属性的实现方式
2020/03/05 Python
python动态规划算法实例详解
2020/11/22 Python
Fashion Eyewear美国:英国线上设计师眼镜和太阳镜的零售商
2016/08/15 全球购物
.net开发工程师面试题
2014/02/25 面试题
创立科技Java面试题
2015/11/29 面试题
出口公司经理求职简历中的自我评价
2013/10/13 职场文书
干部培训自我鉴定
2014/01/22 职场文书
投资建议书模板
2014/05/12 职场文书
冰峪沟导游词
2015/02/09 职场文书
爱的教育读书笔记
2015/06/26 职场文书
七年级之家长会发言稿范文
2019/09/04 职场文书
导游词之山东红叶谷
2019/10/31 职场文书
导游词之西安大清真寺
2019/12/17 职场文书
selenium.webdriver中add_argument方法常用参数表
2021/04/08 Python
Python函数中的不定长参数相关知识总结
2021/06/24 Python