如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Cpy和Python的效率对比
Mar 20 Python
Python函数可变参数定义及其参数传递方式实例详解
May 25 Python
浅谈Python peewee 使用经验
Oct 20 Python
python爬取各类文档方法归类汇总
Mar 22 Python
使用pandas的DataFrame的plot方法绘制图像的实例
May 24 Python
python 在某.py文件中调用其他.py内的函数的方法
Jun 25 Python
在django中实现页面倒数几秒后自动跳转的例子
Aug 16 Python
python绘制无向图度分布曲线示例
Nov 22 Python
tensorflow之自定义神经网络层实例
Feb 07 Python
用Python做一个久坐提醒小助手的示例代码
Feb 10 Python
Django缓存Cache使用详解
Nov 30 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
在php MYSQL中插入当前时间
2008/04/06 PHP
将word转化为swf 如同百度文库般阅读实现思路及代码
2013/08/09 PHP
PHP实现更新中间关联表数据的两种方法
2014/09/01 PHP
php解决DOM乱码的方法示例代码
2016/11/20 PHP
PHP实现的redis主从数据库状态检测功能示例
2017/07/20 PHP
javascript动画对象支持加速、减速、缓入、缓出的实现代码
2012/09/30 Javascript
用于deeplink的js方法(判断手机是否安装app)
2014/04/02 Javascript
jQuery操作表格(table)的常用方法、技巧汇总
2014/04/12 Javascript
JS回调函数的应用简单实例
2014/09/17 Javascript
PHP配置文件php.ini中打开错误报告的设置方法
2015/01/09 PHP
原创jQuery弹出层插件分享
2015/04/02 Javascript
基于jQuery倒计时插件实现团购秒杀效果
2016/05/13 Javascript
jQuery插件FusionCharts绘制的2D帕累托图效果示例【附demo源码】
2017/03/28 jQuery
JS实现的四级密码强度检测功能示例
2017/05/11 Javascript
nginx配置React静态页面的方法教程
2017/11/03 Javascript
JavaScript判断浏览器运行环境的详细方法
2019/06/30 Javascript
taro小程序添加骨架屏的实现代码
2019/11/15 Javascript
Python open读写文件实现脚本
2008/09/06 Python
极简的Python入门指引
2015/04/01 Python
python 添加用户设置密码并发邮件给root用户
2016/07/25 Python
Python实现的字典值比较功能示例
2018/01/08 Python
python的set处理二维数组转一维数组的方法示例
2019/05/31 Python
django的ORM操作 删除和编辑实现详解
2019/07/24 Python
Python 异步协程函数原理及实例详解
2019/11/13 Python
面向新手解析python Beautiful Soup基本用法
2020/07/11 Python
python代码实现图书管理系统
2020/11/30 Python
python 6种方法实现单例模式
2020/12/15 Python
CSS3制作轮播图的一种方法
2019/11/11 HTML / CSS
SHEIN香港:价格实惠的女性时尚服装
2018/08/14 全球购物
Sandro法国官网:法国成衣品牌
2019/08/28 全球购物
教你打造完美的创业计划书
2014/01/06 职场文书
探亲邀请信范文
2014/01/30 职场文书
坚定理想信念心得体会
2014/03/11 职场文书
古诗之感恩老师
2019/10/24 职场文书
Python答题卡识别并给出分数的实现代码
2021/06/22 Python
MySQL中B树索引和B+树索引的区别详解
2022/03/03 MySQL