如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用java模块SmartXLS和jpype修改excel文件的方法
Apr 28 Python
python批量替换页眉页脚实例代码
Jan 22 Python
python3学习之Splash的安装与实例教程
Jul 09 Python
python实现beta分布概率密度函数的方法
Jul 08 Python
python 动态调用函数实例解析
Oct 21 Python
浅谈django框架集成swagger以及自定义参数问题
Jul 07 Python
Python在字符串中处理html和xml的方法
Jul 31 Python
Python3合并两个有序数组代码实例
Aug 11 Python
Python 处理日期时间的Arrow库使用
Aug 18 Python
Django模型验证器介绍与源码分析
Sep 08 Python
如何在python中处理配置文件代码实例
Sep 27 Python
python Pexpect模块的使用
Dec 25 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
剖析 PHP 中的输出缓冲
2006/12/21 PHP
PHP实现采集中国天气网未来7天天气
2014/10/15 PHP
PHP中的Session对象如何使用
2015/09/25 PHP
JavaScript 编程引入命名空间的方法与代码
2007/08/13 Javascript
safari,opera嵌入iframe页面cookie读取问题解决方法
2010/06/23 Javascript
js操作CheckBoxList实现全选/反选(在客服端完成)
2013/02/02 Javascript
jQuery插件实现表格隔行换色且感应鼠标高亮行变色
2013/09/22 Javascript
createTextRange()的使用示例含文本框选中部分文字内容
2014/02/24 Javascript
javascript结合Flexbox简单实现滑动拼图游戏
2016/02/18 Javascript
JS代码防止SQL注入的方法(超简单)
2016/04/12 Javascript
全面理解JavaScript中的继承(必看)
2016/06/16 Javascript
angularjs 表单密码验证自定义指令实现代码
2016/10/27 Javascript
给easyui datebox扩展一个清空的实例
2016/11/09 Javascript
基于jquery实现二级联动效果
2017/03/30 jQuery
JavaScript中闭包的详解
2017/04/01 Javascript
vue 中自定义指令改变data中的值
2017/06/02 Javascript
JavaScript无操作后屏保功能的实现方法
2017/07/04 Javascript
JavaScript闭包和回调详解
2017/08/09 Javascript
微信小程序实现YDUI的ScrollNav组件
2018/02/02 Javascript
Vue filter格式化时间戳时间成标准日期格式的方法
2018/09/16 Javascript
vue-router路由模式详解(小结)
2019/08/26 Javascript
浅谈scrapy 的基本命令介绍
2017/06/13 Python
windows下 兼容Python2和Python3的解决方法
2018/12/05 Python
python selenium实现发送带附件的邮件代码实例
2019/12/10 Python
python UDF 实现对csv批量md5加密操作
2021/01/01 Python
Bodum官网:咖啡和茶壶、玻璃器皿、厨房电器等
2018/08/01 全球购物
this关键字的含义
2015/04/08 面试题
主管职责范文
2013/11/09 职场文书
缅怀先烈演讲稿
2014/09/03 职场文书
后进生评语大全
2015/01/04 职场文书
合理化建议书
2015/02/04 职场文书
年度考核个人总结
2015/03/06 职场文书
2015年民主评议党员工作总结
2015/05/19 职场文书
二审代理词范文
2015/05/25 职场文书
2016大一新生军训心得体会
2016/01/11 职场文书
微软Win11什么功能最惊艳? Windows11新功能特性汇总
2021/11/21 数码科技