tensorflow自定义激活函数实例


Posted in Python onFebruary 04, 2020

前言:因为研究工作的需要,要更改激活函数以适应自己的网络模型,但是单纯的函数替换会训练导致不能收敛。这里还有些不清楚为什么,希望有人可以给出解释。查了一些博客,发现了解决之道。下面将解决过程贴出来供大家指正。

1.背景

之前听某位老师提到说tensorflow可以在不给梯度函数的基础上做梯度下降,所以尝试了替换。我的例子时将ReLU改为平方。即原来的激活函数是 tensorflow自定义激活函数实例 现在换成 tensorflow自定义激活函数实例

单纯替换激活函数并不能较好的效果,在我的实验中,迭代到一定批次,准确率就会下降,最终降为10%左右保持稳定。而事实上,这中间最好的训练精度为92%。资源有限,问了对神经网络颇有研究的同学,说是激活函数的问题,然而某篇很厉害的论文中提到其精度在99%,着实有意思。之后开始研究自己些梯度函数以完成训练。

2.大概流程

首先要确定梯度函数,之后将其处理为tf能接受的类型。

2.1定义自己的激活函数

def square(x):
 return pow(x, 2)

2.2 定义该激活函数的一次梯度函数

def square_grad(x):
 return 2 * x

2.3 让numpy数组每一个元素都能应用该函数(全局)

square_np = np.vectorize(square)
square_grad_np = np.vectorize(square_grad)

2.4 转为tf可用的32位float型,numpy默认是64位(全局)

square_np_32 = lambda x: square_np(x).astype(np.float32)
square_grad_np_32 = lambda x: square_grad_np(x).astype(np.float32)

2.5 定义tf版的梯度函数

def square_grad_tf(x, name=None):
 with ops.name_scope(name, "square_grad_tf", [x]) as name:
 y = tf.py_func(square_grad_np_32, [x], [tf.float32], name=name, stateful=False)
 return y[0]

2.6 定义函数

def my_py_func(func, inp, Tout, stateful=False, name=None, my_grad_func=None):
 # need to generate a unique name to avoid duplicates:
 random_name = "PyFuncGrad" + str(np.random.randint(0, 1E+8))
 tf.RegisterGradient(random_name)(my_grad_func)
 g = tf.get_default_graph()
 with g.gradient_override_map({"PyFunc": random_name, "PyFuncStateless": random_name}):
 return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

2.7 定义梯度,该函数依靠上一个函数my_py_func计算并传播

def _square_grad(op, pred_grad):
 x = op.inputs[0]
 cur_grad = square_grad(x)
 next_grad = pred_grad * cur_grad
 return next_grad

2.8 定义tf版的square函数

def square_tf(x, name=None):
 with ops.name_scope(name, "square_tf", [x]) as name:
 y = my_py_func(square_np_32,
   [x],
   [tf.float32],
   stateful=False,
   name=name,
   my_grad_func=_square_grad)
 return y[0]

3.使用

跟用其他激活函数一样,直接用就行了。input_data:输入数据。

h = square_tf(input_data)

over. 学艺不精,多多指教!

以上这篇tensorflow自定义激活函数实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的闭包详细介绍和实例
Nov 21 Python
浅谈django中的认证与登录
Oct 31 Python
使用Python进行AES加密和解密的示例代码
Feb 02 Python
详解django实现自定义manage命令的扩展
Aug 13 Python
浅析pandas 数据结构中的DataFrame
Oct 12 Python
pygame实现五子棋游戏
Oct 29 Python
基于Python+Appium实现京东双十一自动领金币功能
Oct 31 Python
Pycharm创建项目时如何自动添加头部信息
Nov 14 Python
使用opencv将视频帧转成图片输出
Dec 10 Python
python爬虫添加请求头代码实例
Dec 28 Python
opencv python 对指针仪表读数识别的两种方式
Jan 14 Python
VSCODE配置Markdown及Markdown基础语法详解
Jan 19 Python
pytorch对梯度进行可视化进行梯度检查教程
Feb 04 #Python
pytorch梯度剪裁方式
Feb 04 #Python
基于梯度爆炸的解决方法:clip gradient
Feb 04 #Python
Python 格式化输出_String Formatting_控制小数点位数的实例详解
Feb 04 #Python
python求一个字符串的所有排列的实现方法
Feb 04 #Python
Windows上安装tensorflow  详细教程(图文详解)
Feb 04 #Python
有关Tensorflow梯度下降常用的优化方法分享
Feb 04 #Python
You might like
经典的星际争霸,满是回忆的BGM
2020/04/09 星际争霸
让Nginx支持ThinkPHP的URL重写和PATHINFO的方法分享
2011/08/08 PHP
对淘宝URL中ID提取的PHP代码
2013/09/01 PHP
小型js框架veryide.librar源代码
2009/03/05 Javascript
基于jquery的获取mouse坐标插件的实现代码
2010/04/01 Javascript
解析javascript 数组以及json元素的添加删除
2013/06/26 Javascript
DOM节点深度克隆函数cloneNode()用法实例
2015/01/12 Javascript
浅谈JavaScript的事件
2015/02/27 Javascript
使用javascript实现判断当前浏览器
2015/04/14 Javascript
JS实现的打字机效果完整实例
2016/06/20 Javascript
AngularJs Injecting Services Into Controllers详解
2016/09/02 Javascript
JQuery实现DIV其他动画效果的简单实例
2016/09/18 Javascript
微信小程序 九宫格实例代码
2017/01/21 Javascript
JS实现前端页面的搜索功能
2018/06/12 Javascript
微信小程序生成海报分享朋友圈的实现方法
2019/05/06 Javascript
JavaScript实现网页动态生成表格
2020/11/25 Javascript
Python 流程控制实例代码
2009/09/25 Python
Python中实现switch功能实例解析
2018/01/11 Python
对Tensorflow中的矩阵运算函数详解
2018/07/27 Python
python使用装饰器作日志处理的方法
2019/07/11 Python
django基于cors解决跨域请求问题详解
2019/08/06 Python
python爬虫模拟浏览器访问-User-Agent过程解析
2019/12/28 Python
pytorch实现onehot编码转为普通label标签
2020/01/02 Python
python openCV自制绘画板
2020/10/27 Python
jupyter notebook更换皮肤主题的实现
2021/01/07 Python
使用Python制作一盏 3D 花灯喜迎元宵佳节
2021/02/26 Python
美国祛痘、抗衰老药妆品牌:Murad
2016/08/27 全球购物
database面试题
2013/03/28 面试题
2013年高中生自我评价
2013/10/23 职场文书
国庆促销活动总结
2014/08/29 职场文书
运动会入场词
2015/07/18 职场文书
《中国机长》观后感:敬畏生命,敬畏职责
2019/11/12 职场文书
python plt.plot bar 如何设置绘图尺寸大小
2021/06/01 Python
mysql中between的边界,范围说明
2021/06/08 MySQL
基于Python实现将列表数据生成折线图
2022/03/23 Python
vue cli4中mockjs在dev环境和build环境的配置详情
2022/04/06 Vue.js