pytorch自定义二值化网络层方式


Posted in Python onJanuary 07, 2020

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python操作Mysql实例代码教程在线版(查询手册)
Feb 18 Python
用Python实现QQ游戏大家来找茬辅助工具
Sep 14 Python
python判断字符串是否包含子字符串的方法
Mar 24 Python
Python根据指定日期计算后n天,前n天是哪一天的方法
May 29 Python
python 读写文件,按行修改文件的方法
Jul 12 Python
浅析Python 中几种字符串格式化方法及其比较
Jul 02 Python
浅谈pytorch、cuda、python的版本对齐问题
Jan 15 Python
python对数组进行排序,并输出排序后对应的索引值方式
Feb 28 Python
pycharm开发一个简单界面和通用mvc模板(操作方法图解)
May 27 Python
django 将自带的数据库sqlite3改成mysql实例
Jul 09 Python
利用python制作拼图小游戏的全过程
Dec 04 Python
python 使用tkinter与messagebox写界面和弹窗
Mar 20 Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
python__new__内置静态方法使用解析
Jan 07 #Python
Python常用模块sys,os,time,random功能与用法实例分析
Jan 07 #Python
You might like
扩展你的 PHP 之入门篇
2006/12/04 PHP
php类自动装载、链式操作、魔术方法实现代码
2017/07/23 PHP
PHP数组遍历的几种常见方式总结
2019/02/15 PHP
获取URL地址中的文件名和参数的javascript代码
2009/09/02 Javascript
javascript或asp实现的判断身份证号码是否正确两种验证方法
2009/11/26 Javascript
javascript类型转换示例
2014/04/29 Javascript
轻松创建nodejs服务器(4):路由
2014/12/18 NodeJs
javascript包装对象实例分析
2015/03/27 Javascript
JavaSciprt中处理字符串之sup()方法的使用教程
2015/06/08 Javascript
直接拿来用的15个jQuery代码片段
2015/09/23 Javascript
浅谈angularJS中的事件
2016/07/12 Javascript
js实现淡入淡出轮播切换功能
2017/01/13 Javascript
关于AngularJs数据的本地存储详解
2017/01/20 Javascript
Angular.js组件之input mask对input输入进行格式化详解
2017/07/10 Javascript
EasyUI在Panel上动态添加LinkButton按钮
2017/08/11 Javascript
使用 Vue 绑定单个或多个 Class 名的实例代码
2018/01/08 Javascript
node.js监听文件变化的实现方法
2019/04/17 Javascript
python 中的divmod数字处理函数浅析
2017/10/17 Python
Python基于FTP模块实现ftp文件上传操作示例
2018/04/23 Python
Python3实现取图片中特定的像素替换指定的颜色示例
2019/01/24 Python
python删除列表元素的三种方法(remove,pop,del)
2019/07/22 Python
Python for循环与getitem的关系详解
2020/01/02 Python
在 Windows 下搭建高效的 django 开发环境的详细教程
2020/07/27 Python
Python通用唯一标识符uuid模块使用案例
2020/09/10 Python
python time()的实例用法
2020/11/03 Python
Python 实现图片转字符画的示例(静态图片,gif皆可)
2020/11/05 Python
美国紧身牛仔裤品牌:NYDJ
2017/05/24 全球购物
C语言怎样定义和声明全局变量和函数最好
2013/11/26 面试题
创建精神文明单位实施方案
2014/03/08 职场文书
商铺消防安全责任书
2014/07/29 职场文书
群众路线班子对照检查材料
2014/09/25 职场文书
节水倡议书
2015/01/19 职场文书
入党函调证明材料
2015/06/19 职场文书
学会掌握自己命运的十条黄金法则:
2019/08/08 职场文书
导游词之南京中山陵
2019/11/27 职场文书
Win11更新失败并提示0xc1900101
2022/04/19 数码科技