pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列表append和+的区别浅析
Feb 02 Python
python之PyMongo使用总结
May 26 Python
Python实现字符串反转的常用方法分析【4种方法】
Sep 30 Python
Python爬取当当、京东、亚马逊图书信息代码实例
Dec 09 Python
python 将字符串转换成字典dict的各种方式总结
Mar 23 Python
python操作xlsx文件的包openpyxl实例
May 03 Python
cmd运行python文件时对结果进行保存的方法
May 16 Python
python实现简易学生信息管理系统
Apr 05 Python
Django 实现对已存在的model进行更改
Mar 28 Python
python为什么要安装到c盘
Jul 20 Python
一些让Python代码简洁的实用技巧总结
Aug 23 Python
pandas中pd.groupby()的用法详解
Jun 16 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
用PHP和ACCESS写聊天室(八)
2006/10/09 PHP
php文件打包 下载之使用PHP自带的ZipArchive压缩文件并下载打包好的文件
2012/06/13 PHP
win7+apache+php+mysql环境配置操作详解
2013/06/10 PHP
在PHP中使用X-SendFile头让文件下载更快
2014/06/01 PHP
CI框架文件上传类及图像处理类用法分析
2016/05/18 PHP
php实现根据身份证获取精准年龄
2020/02/26 PHP
jQuery 入门级学习笔记及源码
2010/01/22 Javascript
jQuery EasyUI NumberBox(数字框)的用法
2010/07/08 Javascript
$.ajax返回的JSON无法执行success的解决方法
2011/09/09 Javascript
JavaScript高级程序设计(第3版)学习笔记9 js函数(下)
2012/10/11 Javascript
GRID拖拽行的实例代码
2013/07/18 Javascript
HTML页面弹出居中可拖拽的自定义窗口层
2014/05/07 Javascript
jquery幻灯片插件bxslider样式改进实例
2014/10/15 Javascript
jQuery学习笔记之jQuery中的$
2015/01/19 Javascript
Linux下编译安装php libevent扩展实例
2015/02/14 Javascript
jQuery实现按钮的点击 全选/反选 单选框/复选框 文本框 表单验证
2015/06/25 Javascript
jQuery入门基础知识学习指南
2015/08/14 Javascript
JS查找字符串中出现最多的字符及个数统计
2017/02/04 Javascript
基于 Vue 实现一个酷炫的 menu插件
2017/11/14 Javascript
js自定义trim函数实现删除两端空格功能
2018/02/09 Javascript
Vue.set()动态的新增与修改数据,触发视图更新的方法
2018/09/15 Javascript
JavaScript禁止右击保存图片,禁止拖拽图片的实现代码
2020/04/28 Javascript
[00:33]2016完美“圣”典风云人物:BurNIng宣传片
2016/12/10 DOTA
为什么说Python可以实现所有的算法
2019/10/04 Python
从零开始的TensorFlow+VScode开发环境搭建的步骤(图文)
2020/08/31 Python
英国知名美妆护肤在线商城:Zest Beauty
2018/04/24 全球购物
为娇小女性量身打造:Petite Studio
2018/11/01 全球购物
瑞典耳机品牌:URBANISTA
2019/12/03 全球购物
运动会闭幕式解说词
2014/02/21 职场文书
公司总经理工作职责管理办法
2014/02/28 职场文书
3.12植树节活动总结2014
2014/03/13 职场文书
《动手做做看》教学反思
2014/04/09 职场文书
家长通知书家长意见
2015/06/03 职场文书
学籍证明模板
2015/06/18 职场文书
导游词之太原天龙山
2020/01/02 职场文书
Python爬虫基础讲解之请求
2021/05/13 Python