如何在keras中添加自己的优化器(如adam等)


Posted in Python onJune 19, 2020

本文主要讨论windows下基于tensorflow的keras

1、找到tensorflow的根目录

如果安装时使用anaconda且使用默认安装路径,则在 C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow处可以找到(此处为GPU版本),cpu版本可在C:\ProgramData\Anaconda3\Lib\site-packages\tensorflow处找到。若并非使用默认安装路径,可参照根目录查看找到。

2、找到keras在tensorflow下的根目录

需要特别注意的是找到keras在tensorflow下的根目录而不是找到keras的根目录。一般来说,完成tensorflow以及keras的配置后即可在tensorflow目录下的python目录中找到keras目录,以GPU为例keras在tensorflow下的根目录为C:\ProgramData\Anaconda3\envs\tensorflow-gpu\Lib\site-packages\tensorflow\python\keras

3、找到keras目录下的optimizers.py文件并添加自己的优化器

找到optimizers.py中的adam等优化器类并在后面添加自己的优化器类

以本文来说,我在第718行添加如下代码

@tf_export('keras.optimizers.adamsss')
class Adamsss(Optimizer):

 def __init__(self,
  lr=0.002,
  beta_1=0.9,
  beta_2=0.999,
  epsilon=None,
  schedule_decay=0.004,
  **kwargs):
 super(Adamsss, self).__init__(**kwargs)
 with K.name_scope(self.__class__.__name__):
 self.iterations = K.variable(0, dtype='int64', name='iterations')
 self.m_schedule = K.variable(1., name='m_schedule')
 self.lr = K.variable(lr, name='lr')
 self.beta_1 = K.variable(beta_1, name='beta_1')
 self.beta_2 = K.variable(beta_2, name='beta_2')
 if epsilon is None:
 epsilon = K.epsilon()
 self.epsilon = epsilon
 self.schedule_decay = schedule_decay

 def get_updates(self, loss, params):
 grads = self.get_gradients(loss, params)
 self.updates = [state_ops.assign_add(self.iterations, 1)]

 t = math_ops.cast(self.iterations, K.floatx()) + 1

 # Due to the recommendations in [2], i.e. warming momentum schedule
 momentum_cache_t = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay)))
 momentum_cache_t_1 = self.beta_1 * (
 1. - 0.5 *
 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay)))
 m_schedule_new = self.m_schedule * momentum_cache_t
 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1
 self.updates.append((self.m_schedule, m_schedule_new))

 shapes = [K.int_shape(p) for p in params]
 ms = [K.zeros(shape) for shape in shapes]
 vs = [K.zeros(shape) for shape in shapes]

 self.weights = [self.iterations] + ms + vs

 for p, g, m, v in zip(params, grads, ms, vs):
 # the following equations given in [1]
 g_prime = g / (1. - m_schedule_new)
 m_t = self.beta_1 * m + (1. - self.beta_1) * g
 m_t_prime = m_t / (1. - m_schedule_next)
 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g)
 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t))
 m_t_bar = (
  1. - momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime

 self.updates.append(state_ops.assign(m, m_t))
 self.updates.append(state_ops.assign(v, v_t))

 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon)
 new_p = p_t

 # Apply constraints.
 if getattr(p, 'constraint', None) is not None:
 new_p = p.constraint(new_p)

 self.updates.append(state_ops.assign(p, new_p))
 return self.updates

 def get_config(self):
 config = {
 'lr': float(K.get_value(self.lr)),
 'beta_1': float(K.get_value(self.beta_1)),
 'beta_2': float(K.get_value(self.beta_2)),
 'epsilon': self.epsilon,
 'schedule_decay': self.schedule_decay
 }
 base_config = super(Adamsss, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))

然后修改之后的优化器调用类添加我自己的优化器adamss

需要修改的有(下面的两处修改依旧在optimizers.py内)

# Aliases.

sgd = SGD
rmsprop = RMSprop
adagrad = Adagrad
adadelta = Adadelta
adam = Adam
adamsss = Adamsss
adamax = Adamax
nadam = Nadam

以及

def deserialize(config, custom_objects=None):
 """Inverse of the `serialize` function.

 Arguments:
 config: Optimizer configuration dictionary.
 custom_objects: Optional dictionary mapping
  names (strings) to custom objects
  (classes and functions)
  to be considered during deserialization.

 Returns:
 A Keras Optimizer instance.
 """
 if tf2.enabled():
 all_classes = {
 'adadelta': adadelta_v2.Adadelta,
 'adagrad': adagrad_v2.Adagrad,
 'adam': adam_v2.Adam,
		'adamsss': adamsss_v2.Adamsss,
 'adamax': adamax_v2.Adamax,
 'nadam': nadam_v2.Nadam,
 'rmsprop': rmsprop_v2.RMSprop,
 'sgd': gradient_descent_v2.SGD
 }
 else:
 all_classes = {
 'adadelta': Adadelta,
 'adagrad': Adagrad,
 'adam': Adam,
 'adamax': Adamax,
 'nadam': Nadam,
		'adamsss': Adamsss,
 'rmsprop': RMSprop,
 'sgd': SGD,
 'tfoptimizer': TFOptimizer
 }

这里我们并没有v2版本,所以if后面的部分不改也可以。

4、调用我们的优化器对模型进行设置

model.compile(loss = 'crossentropy', optimizer = 'adamss', metrics=['accuracy'])

5、训练模型

train_history = model.fit(x, y_label, validation_split = 0.2, epoch = 10, batch = 128, verbose = 1)

补充知识:keras设置学习率--优化器的用法

优化器的用法

优化器 (optimizer) 是编译 Keras 模型的所需的两个参数之一:

from keras import optimizers
 
model = Sequential()
model.add(Dense(64, kernel_initializer='uniform', input_shape=(10,)))
model.add(Activation('softmax'))
 
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

你可以先实例化一个优化器对象,然后将它传入 model.compile(),像上述示例中一样, 或者你可以通过名称来调用优化器。在后一种情况下,将使用优化器的默认参数。

# 传入优化器名称: 默认参数将被采用
model.compile(loss='mean_squared_error', optimizer='sgd')

以上这篇如何在keras中添加自己的优化器(如adam等)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python的chardet库获得文件编码并修改编码
Jan 22 Python
python写的ARP攻击代码实例
Jun 04 Python
Python数据类型详解(一)字符串
May 08 Python
python中函数总结之装饰器闭包详解
Jun 12 Python
利用python解决mysql视图导入导出依赖的问题
Dec 17 Python
python使用BeautifulSoup与正则表达式爬取时光网不同地区top100电影并对比
Apr 15 Python
解决Django Static内容不能加载显示的问题
Jul 28 Python
python虚拟环境的安装和配置(virtualenv,virtualenvwrapper)
Aug 09 Python
matlab灰度图像调整及imadjust函数的用法详解
Feb 27 Python
Python如何省略括号方法详解
Mar 21 Python
详解Python3.8+PyQt5+pyqt5-tools+Pycharm配置详细教程
Nov 02 Python
使用BeautifulSoup4解析XML的方法小结
Dec 07 Python
详解pyinstaller生成exe的闪退问题解决方案
Jun 19 #Python
Python实现爬取并分析电商评论
Jun 19 #Python
keras 实现轻量级网络ShuffleNet教程
Jun 19 #Python
Python爬虫实现HTTP网络请求多种实现方式
Jun 19 #Python
Keras设置以及获取权重的实现
Jun 19 #Python
Python包和模块的分发详细介绍
Jun 19 #Python
浅谈Keras中shuffle和validation_split的顺序
Jun 19 #Python
You might like
VOLVO车载收音机
2021/03/02 无线电
php面向对象全攻略 (九)访问类型
2009/09/30 PHP
php数组删除元素示例
2014/03/21 PHP
PHP调用其他文件中的类
2018/04/02 PHP
PHP中抽象类,接口功能、定义方法示例
2019/02/26 PHP
TP5框架页面跳转样式操作示例
2020/04/05 PHP
PHP Web表单生成器案例分析
2020/06/02 PHP
如何通过Apache在本地配置多个虚拟主机
2020/07/29 PHP
JQury slideToggle闪烁问题及解决办法
2011/07/05 Javascript
分享精心挑选的23款美轮美奂的jQuery 图片特效插件
2012/08/14 Javascript
JavaScript动态创建div属性和样式示例代码
2013/10/09 Javascript
jquery插件jSignature实现手动签名
2015/05/04 Javascript
jquery性能优化高级技巧
2015/08/24 Javascript
详解AngularJS Filter(过滤器)用法
2015/12/28 Javascript
获取input标签的所有属性的方法
2016/06/28 Javascript
Javascript使用SWFUpload进行多文件上传
2016/11/16 Javascript
angularjs实现猜数字大小功能
2020/05/20 Javascript
bootstrap实现二级下拉菜单效果
2017/11/23 Javascript
nodejs实现简单的gulp打包
2017/12/21 NodeJs
微信小程序 拍照或从相册选取图片上传代码实例
2019/08/28 Javascript
vue表单中遍历表单操作按钮的显示隐藏示例
2019/10/30 Javascript
pydev使用wxpython找不到路径的解决方法
2013/02/10 Python
python中使用urllib2获取http请求状态码的代码例子
2014/07/07 Python
Python中优化NumPy包使用性能的教程
2015/04/23 Python
Python中的ceil()方法使用教程
2015/05/14 Python
完美解决Python 2.7不能正常使用pip install的问题
2018/06/12 Python
python 利用pyttsx3文字转语音过程详解
2019/09/25 Python
torch 中各种图像格式转换的实现方法
2019/12/26 Python
Myprotein中国网站:欧洲畅销运动营养品牌
2021/02/11 全球购物
mysql_pconnect()和mysql_connect()有什么区别
2012/05/25 面试题
计算s=f(f(-1.4))的值
2014/05/06 面试题
电大学习个人自我评价范文
2013/10/04 职场文书
数据员岗位职责
2013/11/19 职场文书
惹女朋友生气检讨书
2015/05/06 职场文书
农业项目投资意向书
2015/05/09 职场文书
2015银行年终工作总结范文
2015/05/26 职场文书