如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
以视频爬取实例讲解Python爬虫神器Beautiful Soup用法
Jan 20 Python
Python中使用装饰器来优化尾递归的示例
Jun 18 Python
Django实现分页功能
Jul 02 Python
Python和Java的语法对比分析语法简洁上python的确完美胜出
May 10 Python
pandas通过字典生成dataframe的方法步骤
Jul 23 Python
Python中函数的返回值示例浅析
Aug 28 Python
python输出数组中指定元素的所有索引示例
Dec 06 Python
简单了解python shutil模块原理及使用方法
Apr 28 Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 Python
Python selenium键盘鼠标事件实现过程详解
Jul 28 Python
Python定时任务框架APScheduler原理及常用代码
Oct 05 Python
python从PDF中提取数据的示例
Oct 30 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
php面向对象全攻略 (六)__set() __get() __isset() __unset()的用法
2009/09/30 PHP
php中static静态变量的使用方法详解
2010/06/04 PHP
php设计模式 Singleton(单例模式)
2011/06/26 PHP
php中用加号与用array_merge合并数组的区别深入分析
2013/06/03 PHP
Yii实现MySQL多数据库和读写分离实例分析
2014/12/03 PHP
fckeditor上传文件按日期存放及重命名方法
2015/05/22 PHP
用php和jQuery来实现“顶”和“踩”的投票功能
2016/10/13 PHP
PHP解压ZIP文件到指定文件夹的方法
2016/11/17 PHP
深入理解 PHP7 中全新的 zval 容器和引用计数机制
2018/10/15 PHP
javascript的事件触发器介绍的实现
2014/06/05 Javascript
jquery插件jSignature实现手动签名
2015/05/04 Javascript
基于JQuery实现仿网易邮箱全屏动感滚动插件fullPage
2015/09/20 Javascript
常用javascript表单验证汇总
2020/07/20 Javascript
JavaScript中两个字符串的匹配
2016/06/08 Javascript
JavaScript 链式结构序列化详解
2016/09/30 Javascript
微信小程序 教程之小程序配置
2016/10/17 Javascript
node实现基于token的身份验证
2018/04/09 Javascript
详解React项目中碰到的IE问题
2019/03/14 Javascript
微信小程序在其他页面监听globalData中值的变化
2019/07/15 Javascript
vue 解决数组赋值无法渲染在页面的问题
2019/10/28 Javascript
jQuery实现简单QQ聊天框
2020/08/27 jQuery
Vue单页面应用中实现Markdown渲染
2021/02/14 Vue.js
[03:46]DOTA2英雄基础教程 维萨吉
2013/12/11 DOTA
[00:10]DOTA2 TI9勇士令状明日上线
2019/05/07 DOTA
Linux下Python获取IP地址的代码
2014/11/30 Python
解决Python一行输出不显示的问题
2018/12/03 Python
Python Excel处理库openpyxl使用详解
2019/05/09 Python
Python中生成ndarray实例讲解
2021/02/22 Python
波兰家居和花园家具专家:4Home
2019/05/26 全球购物
盖尔斯工厂店:GUESS Factory
2020/01/21 全球购物
小学生成长感言
2014/01/30 职场文书
工程专业应届生求职信
2014/02/19 职场文书
煤矿安全知识竞赛活动总结
2014/07/07 职场文书
纪念九一八事变演讲稿1000字
2014/09/14 职场文书
税务职业生涯规划书范文
2014/09/16 职场文书
毕业班班主任工作总结2015
2015/07/23 职场文书