如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python获取电脑硬件信息及状态的实现方法
Aug 29 Python
python实现通过shelve修改对象实例
Sep 26 Python
解析Python编程中的包结构
Oct 25 Python
Python的pycurl包用法简介
Nov 13 Python
深入浅出分析Python装饰器用法
Jul 28 Python
Python程序员面试题 你必须提前准备!
Jan 16 Python
详解python while 函数及while和for的区别
Sep 07 Python
Python判断有效的数独算法示例
Feb 23 Python
python3人脸识别的两种方法
Apr 25 Python
Django为窗体加上防机器人的验证码功能过程解析
Aug 14 Python
python 根据网易云歌曲的ID 直接下载歌曲的实例
Aug 24 Python
Python如何使用ConfigParser读取配置文件
Nov 12 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
PHP闭包实例解析
2014/09/08 PHP
php简单日历函数
2015/10/28 PHP
WordPress中用于检索模版的相关PHP函数使用解析
2015/12/15 PHP
浅谈PHP定义命令空间的几个注意点(推荐)
2016/10/29 PHP
在html页面中包含共享页面的方法
2008/10/24 Javascript
微博@符号的用户名提示效果。(想@到谁?)
2010/11/05 Javascript
再谈javascript图片预加载技术(详细演示)
2011/03/12 Javascript
JS操作图片(增,删,改) 例子
2013/04/17 Javascript
用jquery仿做发微博功能示例
2014/04/18 Javascript
js实现的tab标签切换效果代码分享
2015/08/25 Javascript
AngularJS延迟加载html template
2016/07/27 Javascript
利用JS提交表单的几种方法和验证(必看篇)
2016/09/17 Javascript
基于jQuery Easyui实现登陆框界面
2017/07/10 jQuery
表格展示利器 Bootstrap Table实例代码
2017/09/06 Javascript
vue 中this.$set 动态绑定数据的案例讲解
2021/01/29 Vue.js
[22:20]初生之犊-TI4第5名LGD战队纪录片
2014/08/13 DOTA
使用Python生成XML的方法实例
2017/03/21 Python
python批量解压zip文件的方法
2019/08/20 Python
Python中的list与tuple集合区别解析
2019/10/12 Python
Python高级特性之闭包与装饰器实例详解
2019/11/19 Python
pycharm运行scrapy过程图解
2019/11/22 Python
解决Tensorboard 不显示计算图graph的问题
2020/02/15 Python
python 截取XML中bndbox的坐标中的图像,另存为jpg的实例
2020/03/10 Python
python中shell执行知识点
2020/05/06 Python
PyCharm2019 安装和配置教程详解附激活码
2020/07/31 Python
奥地利网上书店:Weltbild
2017/07/14 全球购物
向全球直邮输送天然健康产品:iHerb.com
2020/05/03 全球购物
关键字final的用法
2013/10/02 面试题
Java里面Pass by value和Pass by Reference是什么意思
2016/05/02 面试题
加拿大探亲邀请信
2014/01/28 职场文书
大队委竞选演讲稿
2014/04/28 职场文书
竞选生活委员演讲稿
2014/04/28 职场文书
1亿有多大教学反思
2014/05/01 职场文书
超市商业计划书
2014/05/04 职场文书
大学迎新标语
2014/06/26 职场文书
Python文件的操作示例的详细讲解
2021/04/08 Python