TensorFlow学习之分布式的TensorFlow运行环境


Posted in Python onFebruary 05, 2020

当我们在大型的数据集上面进行深度学习的训练时,往往需要大量的运行资源,而且还要花费大量时间才能完成训练。

1.分布式TensorFlow的角色与原理

在分布式的TensorFlow中的角色分配如下:

PS:作为分布式训练的服务端,等待各个终端(supervisors)来连接。

worker:在TensorFlow的代码注释中被称为终端(supervisors),作为分布式训练的计算资源终端。

chief supervisors:在众多的运算终端中必须选择一个作为主要的运算终端。该终端在运算终端中最先启动,它的功能是合并各个终端运算后的学习参数,将其保存或者载入。

每个具体的网络标识都是唯一的,即分布在不同IP的机器上(或者同一个机器的不同端口)。在实际的运行中,各个角色的网络构建部分代码必须100%的相同。三者的分工如下:

服务端作为一个多方协调者,等待各个运算终端来连接。

chief supervisors会在启动时同一管理全局的学习参数,进行初始化或者从模型载入。

其他的运算终端只是负责得到其对应的任务并进行计算,并不会保存检查点,用于TensorBoard可视化中的summary日志等任何参数信息。

在整个过程都是通过RPC协议来进行通信的。

2.分布部署TensorFlow的具体方法

配置过程中,首先建立一个server,在server中会将ps及所有worker的IP端口准备好。接着,使用tf.train.Supervisor中的managed_ssion来管理一个打开的session。session中只是负责运算,而通信协调的事情就都交给supervisor来管理了。

3.部署训练实例

下面开始实现一个分布式训练的网络模型,以线性回归为例,通过3个端口来建立3个终端,分别是一个ps,两个worker,实现TensorFlow的分布式运算。

1. 为每个角色添加IP地址和端口,创建sever,在一台机器上开3个不同的端口,分别代表PS,chief supervisor和worker。角色的名称用strjob_name表示,以ps为例,代码如下:

# 定义IP和端口
strps_hosts = 'localhost:1681'
strworker_hosts = 'localhost:1682,localhost:1683'
# 定义角色名称
strjob_name = 'ps'
task_index = 0
# 将字符串转数组
ps_hosts = strps_hosts.split(',')
worker_hosts = strps_hosts.split(',')
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
# 创建server
server = tf.train.Server({'ps':ps_hosts, 'worker':worker_hosts}, job_name=strjob_name, task_index=task_index)

2为ps角色添加等待函数

ps角色使用server.join函数进行线程挂起,开始接受连续消息。

# ps角色使用join进行等待
if strjob_name == 'ps':
  print("wait")
  server.join()

3.创建网络的结构

与正常的程序不同,在创建网络结构时,使用tf.device函数将全部的节点都放在当前任务下。在tf.device函数中的任务是通过tf.train.replica_device_setter来指定的。在tf.train.replica_device_setter中使用worker_device来定义具体任务名称;使用cluster的配置来指定角色及对应的IP地址,从而实现管理整个任务下的图节点。代码如下:

with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:%d'%task_index,
                       cluster=cluster_spec)):
  X = tf.placeholder('float')
  Y = tf.placeholder('float')
  # 模型参数
  w = tf.Variable(tf.random_normal([1]), name='weight')
  b = tf.Variable(tf.zeros([1]), name='bias')
  global_step = tf.train.get_or_create_global_step()  # 获取迭代次数
  z = tf.multiply(X, w) + b
  tf.summary('z', z)
  cost = tf.reduce_mean(tf.square(Y - z))
  tf.summary.scalar('loss_function', cost)
  learning_rate = 0.001
  optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost, global_step=global_step)
  saver = tf.train.Saver(max_to_keep=1)
  merged_summary_op = tf.summary.merge_all() # 合并所有summary
  init = tf.global_variables_initializer()

4.创建Supercisor,管理session

在tf.train.Supervisor函数中,is_chief表明为是否为chief Supervisor角色,这里将task_index=0的worker设置成chief Supervisor。saver需要将保存检查点的saver对象传入。init_op表示使用初始化变量的函数。

training_epochs = 2000
display_step = 2
sv = tf.train.Supervisor(is_chief=(task_index == 0),# 0号为chief
             logdir='log/spuer/',
             init_op=init,
             summary_op=None,
             saver=saver,
             global_step=global_step,
             save_model_secs=5)
# 连接目标角色创建session
with sv.managed_session(saver.target) as sess:

5迭代训练

session中的内容与以前一样,直接迭代训练即可。由于使用了supervisor管理session,将使用sv.summary_computed函数来保存summary文件。

print('sess ok')
  print(global_step.eval(session=sess))
  for epoch in range(global_step.eval(session=sess), training_epochs*len(train_x)):
    for (x, y) in zip(train_x, train_y):
      _, epoch = sess.run([optimizer, global_step], feed_dict={X: x, Y: y})
      summary_str = sess.run(merged_summary_op, feed_dict={X: x, Y: y})
      sv.summary_computed(sess, summary_str, global_step=epoch)
      if epoch % display_step == 0:
        loss = sess.run(cost, feed_dict={X:train_x, Y:train_y})
        print("Epoch:", epoch+1, 'loss:', loss, 'W=', sess.run(w), w, 'b=', sess.run(b))
  print(' finished ')
  sv.saver.save(sess, 'log/linear/' + "sv.cpk", global_step=epoch)
sv.stop()

(1)在设置自动保存检查点文件后,手动保存仍然有效,

(2)在运行一半后,在运行supervisor时会自动载入模型的参数,不需要手动调用restore。

(3)在session中不需要进行初始化的操作。

6.建立worker文件

新建两个py文件,设置task_index分别为0和1,其他的部分和上述的代码相一致。

strjob_name = 'worker'
task_index = 1
strjob_name = 'worker'
task_index = 0

7.运行

我们分别启动写好的三个文件,在运行结果中,我们可以看到循环的次数不是连续的,显示结果中会有警告,这是因为在构建supervisor时没有填写local_init_op参数,该参数的含义是在创建worker实例时,初始化本地变量,上述代码中没有设置,系统会自动初始化,并给出警告提示。

分布运算的目的是为了提高整体运算速度,如果同步epoch的准确率需要牺牲总体运行速度为代价,自然很不合适。

在ps的文件中,它只是负责连接,并不参与运算。

总结

以上所述是小编给大家介绍的TensorFlow学习之分布式的TensorFlow运行环境,希望对大家有所帮助!!

Python 相关文章推荐
sqlalchemy对象转dict的示例
Apr 22 Python
深入理解python多进程编程
Jun 12 Python
Python3下错误AttributeError: ‘dict’ object has no attribute’iteritems‘的分析与解决
Jul 06 Python
python3 模拟登录v2ex实例讲解
Jul 13 Python
浅析Python中return和finally共同挖的坑
Aug 18 Python
python的格式化输出(format,%)实例详解
Jun 01 Python
python定向爬虫校园论坛帖子信息
Jul 23 Python
浅谈python 中的 type(), dtype(), astype()的区别
Apr 09 Python
Python 如何测试文件是否存在
Jul 31 Python
利用python3筛选excel中特定的行(行值满足某个条件/行值属于某个集合)
Sep 04 Python
python 用pandas实现数据透视表功能
Dec 21 Python
python开发飞机大战游戏
Jul 15 Python
TensorFlow MNIST手写数据集的实现方法
Feb 05 #Python
tensorflow之并行读入数据详解
Feb 05 #Python
tensorflow mnist 数据加载实现并画图效果
Feb 05 #Python
tensorflow 自定义损失函数示例代码
Feb 05 #Python
利用Tensorflow的队列多线程读取数据方式
Feb 05 #Python
Tensorflow 多线程与多进程数据加载实例
Feb 05 #Python
TensorFlow自定义损失函数来预测商品销售量
Feb 05 #Python
You might like
蝙蝠侠:侠影之谜
2020/03/04 欧美动漫
Terran建筑一览
2020/03/14 星际争霸
php 输出双引号"与单引号'的方法
2010/05/09 PHP
php实现的DateDiff和DateAdd时间函数代码分享
2014/08/16 PHP
PHP实现的sqlite数据库连接类
2014/12/12 PHP
php+mysql实现无限分类实例详解
2015/01/15 PHP
php连接oracle数据库的核心步骤
2016/05/26 PHP
php+Ajax处理xml与json格式数据的方法示例
2019/03/04 PHP
THINKPHP-Apache服务器中使用Alias虚拟目录URL重写 隐藏index.php
2021/03/09 PHP
IE和Firefox下javascript的兼容写法小结
2008/12/10 Javascript
js对数字的格式化使用说明
2011/01/12 Javascript
jQuery autocomplate 自扩展插件、自动完成示例代码
2011/03/28 Javascript
关于query Javascript CSS Selector engine
2013/04/12 Javascript
jquery实现图片滚动效果的简单实例
2013/11/23 Javascript
js document.write()使用介绍
2014/02/21 Javascript
Javascript中获取对象的原型对象的方法小结
2015/02/25 Javascript
jQuery简单实现验证邮箱格式
2015/07/15 Javascript
JS遍历数组和对象的区别及递归遍历对象、数组、属性的方法详解
2016/06/14 Javascript
jQuery下拉菜单的实现代码
2016/11/03 Javascript
JavaScript中的call和apply的用途以及区别
2017/01/11 Javascript
微信小程序后台解密用户数据实例详解
2017/06/28 Javascript
Vue登录注册并保持登录状态的方法
2018/08/17 Javascript
JavaScript使用递归和循环实现阶乘的实例代码
2018/08/28 Javascript
postman自定义函数实现 时间函数的思路详解
2019/04/17 Javascript
python在多玩图片上下载妹子图的实现代码
2013/08/13 Python
python3新特性函数注释Function Annotations用法分析
2016/07/28 Python
详解Python核心编程中的浅拷贝与深拷贝
2018/01/07 Python
pandas的object对象转时间对象的方法
2018/04/11 Python
Python使用Pandas读写Excel实例解析
2019/11/19 Python
ReVive利维肤美国官网:RéVive Skincare
2018/04/18 全球购物
建筑总经理岗位职责
2014/02/02 职场文书
2014年入党积极分子党校培训心得体会
2014/07/08 职场文书
经典导游欢迎词
2015/01/26 职场文书
教师辞职书范文
2015/02/26 职场文书
电影圆明园观后感
2015/06/03 职场文书
从原生JavaScript到React深入理解
2022/07/23 Javascript