如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python代码检查工具pylint 让你的python更规范
Sep 05 Python
Python struct模块解析
Jun 12 Python
python用10行代码实现对黄色图片的检测功能
Aug 10 Python
python学习入门细节知识点
Mar 29 Python
详解Django中类视图使用装饰器的方式
Aug 12 Python
python 自定义异常和异常捕捉的方法
Oct 18 Python
Python 面试中 8 个必考问题
Nov 16 Python
深度辨析Python的eval()与exec()的方法
Mar 26 Python
使用python serial 获取所有的串口名称的实例
Jul 02 Python
Python笔记之观察者模式
Nov 20 Python
详解Pycharm出现out of memory的终极解决方法
Mar 03 Python
用OpenCV进行年龄和性别检测的实现示例
Jan 29 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
PHP中设置时区,记录日志文件的实现代码
2013/01/07 PHP
PHP统计目录中文件以及目录中目录大小的方法
2016/01/09 PHP
PHP5.5.15+Apache2.4.10+MySQL5.6.20配置方法分享
2016/05/06 PHP
laravel 时间格式转时间戳的例子
2019/10/11 PHP
区分JS中的undefined,null,"",0和false
2007/03/08 Javascript
javascript温习的一些笔记 基础常用知识小结
2011/06/22 Javascript
IE6-IE9不支持table.innerHTML的解决方法分享
2012/09/14 Javascript
document.documentElement的一些使用技巧
2013/04/18 Javascript
让网页跳转到指定位置的jquery代码非书签
2013/09/06 Javascript
jquery 操作iframe的几种方法总结
2013/12/13 Javascript
javascript中定义私有方法说明(private method)
2014/01/27 Javascript
js小数运算出现多位小数如何解决
2015/10/08 Javascript
jquery实现树形菜单完整代码
2015/12/29 Javascript
使用jquery的jsonp如何发起跨域请求及其原理详解
2017/08/17 jQuery
vue生成随机验证码的示例代码
2017/09/29 Javascript
详解关闭令人抓狂的ESlint 语法检测配置方法
2019/10/28 Javascript
Vue项目中使用jsonp抓取跨域数据的方法
2019/11/10 Javascript
python中将阿拉伯数字转换成中文的实现代码
2011/05/19 Python
Python实现字典的key和values的交换
2015/08/04 Python
对Python subprocess.Popen子进程管道阻塞详解
2018/10/29 Python
判断Threading.start新线程是否执行完毕的实例
2020/05/02 Python
基于pycharm实现批量修改变量名
2020/06/02 Python
python基于pygame实现飞机大作战小游戏
2020/11/19 Python
matplotlib运行时配置(Runtime Configuration,rc)参数rcParams解析
2021/01/05 Python
html5音频_动力节点Java学院整理
2018/08/22 HTML / CSS
英国在线花园中心:You Garden
2018/06/03 全球购物
世界汽车零件:World Car Parts
2019/09/04 全球购物
最新大学生自我评价
2013/09/24 职场文书
校园活动策划书范文
2014/01/10 职场文书
安全生产管理责任书
2014/04/16 职场文书
社区领导班子四风问题原因分析及整改措施
2014/09/28 职场文书
党员教师自我剖析材料
2014/09/29 职场文书
2014年自愿离婚协议书
2014/10/10 职场文书
党员对照检查剖析材料
2014/10/13 职场文书
个人业务学习心得体会
2016/01/25 职场文书
应届毕业生的自我评价
2019/06/21 职场文书