如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取从命令行输入数字的方法
Apr 29 Python
python使用socket远程连接错误处理方法
Apr 29 Python
Python常用时间操作总结【取得当前时间、时间函数、应用等】
May 11 Python
Python最小二乘法矩阵
Jan 02 Python
详解Python匿名函数(lambda函数)
Apr 19 Python
int在python中的含义以及用法
Jun 27 Python
用Python+OpenCV对比图像质量的几种方法
Jul 15 Python
Django2 连接MySQL及model测试实例分析
Dec 10 Python
Python @property装饰器原理解析
Jan 22 Python
使用python从三个角度解决josephus问题的方法
Mar 27 Python
在Sublime Editor中配置Python环境的详细教程
May 03 Python
如何使用PyCharm及常用配置详解
Jun 03 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
高亮度显示php源代码
2006/10/09 PHP
php使用json_encode对变量json编码
2014/04/07 PHP
php数组合并array_merge()函数使用注意事项
2014/06/19 PHP
PHP中返回引用类型的方法
2015/04/03 PHP
PHP抽象类和接口用法实例详解
2019/07/20 PHP
验证用户是否修改过页面的数据的实现方法
2008/09/26 Javascript
js实现带搜索功能的下拉框实时搜索实时匹配
2013/11/05 Javascript
jquery html动态生成select标签出问题的解决方法
2013/11/20 Javascript
JS基于clipBoard.js插件实现剪切、复制、粘贴
2016/05/03 Javascript
jacascript DOM节点——元素节点、属性节点、文本节点
2017/04/18 Javascript
js 简易版滚动条实例(适用于移动端H5开发)
2017/06/26 Javascript
jQuery实现鼠标响应式透明度渐变动画效果示例
2018/02/13 jQuery
vue2 全局变量的设置方法
2018/03/09 Javascript
jQuery实现下拉菜单动态添加数据点击滑出收起其他功能
2018/06/14 jQuery
详解小程序输入框闪烁及重影BUG解决方案
2018/08/31 Javascript
优化Vue项目编译文件大小的方法步骤
2019/05/27 Javascript
vue实现树形结构样式和功能的实例代码
2019/10/15 Javascript
微信小程序实现录制、试听、上传音频功能(带波形图)
2020/02/27 Javascript
JS实现图片懒加载(lazyload)过程详解
2020/04/02 Javascript
[00:12]2018DOTA2亚洲邀请赛 Sccc亮相SOLO赛,今年他又会有什么样的战绩?
2018/04/06 DOTA
Python3里的super()和__class__使用介绍
2015/04/23 Python
Python判断文件和文件夹是否存在的方法
2015/05/21 Python
Python爬虫工程师面试问题总结
2018/03/22 Python
Linux下python3.6.1环境配置教程
2018/09/26 Python
解决python replace函数替换无效问题
2020/01/18 Python
TensorFlow 多元函数的极值实例
2020/02/10 Python
使用Keras预训练模型ResNet50进行图像分类方式
2020/05/23 Python
Python getattr()函数使用方法代码实例
2020/08/10 Python
CSS3 按钮边框动画的实现
2020/11/12 HTML / CSS
美国女士时尚珠宝及配饰购物网站:Icing
2018/07/02 全球购物
简述安装Slackware Linux系统的过程
2012/05/08 面试题
小学毕业感言50字
2014/02/16 职场文书
领导班子党的群众路线对照检查材料
2014/09/25 职场文书
投资意向协议书
2015/01/29 职场文书
法院执行局工作总结
2015/08/11 职场文书
js前端设计模式优化50%表单校验代码示例
2022/06/21 Javascript