如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python基于Tkinter实现的记事本实例
Jun 17 Python
Python模拟登录验证码(代码简单)
Feb 06 Python
python3.4用循环往mysql5.7中写数据并输出的实现方法
Jun 20 Python
python判断设备是否联网的方法
Jun 29 Python
Python清空文件并替换内容的实例
Oct 22 Python
Python通用循环的构造方法实例分析
Dec 19 Python
python 格式化输出百分号的方法
Jan 20 Python
Pytorch 中retain_graph的用法详解
Jan 07 Python
python矩阵运算,转置,逆运算,共轭矩阵实例
May 11 Python
如何基于python把文字图片写入word文档
Jul 31 Python
Python中Pyspider爬虫框架的基本使用详解
Jan 27 Python
python中使用asyncio实现异步IO实例分析
Feb 26 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
NO3第三帝国留言簿制作过程
2006/10/09 PHP
php简单判断文本编码的方法
2015/07/30 PHP
疯掉了,尽然有js写的操作系统
2007/04/23 Javascript
javascript数字格式化通用类 accounting.js使用
2012/08/24 Javascript
基于JavaScript实现根据手机定位获取当前具体位置(X省X市X县X街道X号)
2015/12/29 Javascript
Javascript 制作图形验证码实例详解
2016/12/22 Javascript
创建一般js对象的几种方式
2017/01/19 Javascript
AngularJS监听路由变化的方法
2017/03/07 Javascript
实例详解JSON取值(key是中文或者数字)方式
2017/08/24 Javascript
vue实现简单的星级评分组件源码
2018/11/16 Javascript
Vue实现类似Spring官网图片滑动效果方法
2019/03/01 Javascript
webpack中如何加载静态文件的方法步骤
2019/05/18 Javascript
HTML+JS实现“代码雨”效果源码(黑客帝国文字下落效果)
2020/03/17 Javascript
[33:23]VG vs Pain 2018国际邀请赛小组赛BO2 第二场 8.18
2018/08/19 DOTA
[35:39]完美世界DOTA2联赛PWL S2 FTD.C vs Rebirth 第二场 11.22
2020/11/24 DOTA
[52:03]DOTA2-DPC中国联赛 正赛 Ehome vs iG BO3 第三场 1月31日
2021/03/11 DOTA
关于python pyqt5安装失败问题的解决方法
2017/08/08 Python
Python 判断 有向图 是否有环的实例讲解
2018/02/01 Python
解决Python print 输出文本显示 gbk 编码错误问题
2018/07/13 Python
基于Python实现2种反转链表方法代码实例
2020/07/06 Python
CSS3 3D制作实战案例分析
2016/09/18 HTML / CSS
Get The Label中文官网:英国运动时尚购物平台
2017/04/19 全球购物
意大利在线眼镜精品店:Ottica Lipari
2019/11/11 全球购物
武汉东之林科技有限公司机试
2013/09/17 面试题
外语专业毕业生个人的自荐信
2013/11/19 职场文书
国贸专业个人求职信分享
2013/12/04 职场文书
寄语学生的话
2014/04/10 职场文书
国际贸易专业求职信
2014/06/04 职场文书
电焊工岗位工作职责
2014/07/09 职场文书
购房意向书
2014/08/30 职场文书
公司法定代表人授权委托书
2014/09/29 职场文书
科长个人四风问题整改措施思想汇报
2014/10/13 职场文书
幼儿园综治宣传月活动总结
2015/05/07 职场文书
2016年班主任培训心得体会
2016/01/07 职场文书
pytorch实现手写数字图片识别
2021/05/20 Python
详解如何用Python实现感知器算法
2021/06/18 Python