自定义命令行参数
- 定义命令行参数的方法
- tf.app.flags 支持应从命令行接收参数, 可以用来指定集群的配置等, 有下面的参数类型
- DEFINE_string(flag_name, default_value, docstring)
- DEFINE_integer(flag_name, default_value, docstring)
- DEFINE_boolean(flag_name, default_value, docstring)
- DEFINE_float(flag_name, default_value, docstring)
- 调用命令行参数的方法
- tf.app.flags.FlAGS.xxx
- xxx 为上面定义的flag_name
- 通过命令行启动的方法
- tf.app.run() 可以截止启动main(argv) 函数
- 可以在命令行中启动并传入参数, 程序可以直接帮我们运行main()函数
分布式TensorFlow
- 分布式TensowFlow介绍
- Tensorflow的一个特色就是分布式计算。分布式Tensorflow是由高性能的gRPC框架作为底层技术来支持的。这是一个通信框架gRPC(google remote procedure call),是一个高性能、跨平台的RPC框架。RPC协议,即远程过程调用协议,是指通过网络从远程计算机程序上请求服务。
- 分布式原理
- Tensorflow分布式是由多个服务器进程和客户端进程组成。有几种部署方式,列如单机多卡和多机多卡(分布式)。
- 单机多卡(单台服务器有多块GPU设备)
- 在单机单GPU的训练中,数据是一个batch一个batch的训练。 在单机多GPU中,数据一次处理4个batch(假设是4个GPU训练), 每个GPU处理一个batch的数据计算。
- 变量,或者说参数,保存在CPU上。数据由CPU分发给4个GPU,在GPU上完成计算,得到每个批次要更新的梯度
- 在CPU上收集完4个GPU上要更新的梯度,计算一下平均梯度,然后更新。
- 循环进行上面步骤
- 多机多卡
- 而分布式是指有多台计算机,充分使用多台计算机的性能,处理数据的能力。可以根据不同计算机划分不同的工作节点。当数据量或者计算量达到超过一台计算机处理能力的上上限的话,必须使用分布式
- 分布式架构
- 当我们知道的基本的分布式原理之后,我们来看看分布式的架构的组成。分布式架构的组成可以说是一个集群的组成方式。那么一般我们在进行Tensorflow分布式时,需要建立一个集群。通常是我们分布式的作业集合。一个作业中又包含了很多的任务(工作结点),每个任务由一个工作进程来执行。
- 分布式节点之间的关系
- 一般来说,在分布式机器学习框架中,我们会把作业分成参数作业(parameter job)和工作结点作业(worker job)。
- 参数服务器:
- 运行参数作业的服务器我们称之为参数服务器(parameter server,PS),负责管理参数的存储和更新
- 参数服务器,当模型越来越大时,模型的参数越来越多,多到一台机器的性能不够完成对模型参数的更新的时候,就需要把参数分开放到不同的机器去存储和更新。参数服务器可以是由多台机器组成的集群。
- worker服务器:
- 工作结点作业负责主要从事计算的任务,如运行操作。
- Tensorflow的分布式实现了作业间的数据传输,也就是参数作业到工作结点作业的前向传播,以及工作节点到参数作业的反向传播。
- 所有的Worker服务器会有一个主worker服务器, 会话的创建, 文件的读取和存储只会在主worker上进行
分布式的模式
- 数据并行
- 原理
- 数据并总的原理很简单。其中CPU主要负责梯度平均和参数更新,而GPU主要负责训练模型副本。
- 实现
- 模型副本定义在GPU上
- 对于每一个GPU,都是从CPU获得数据,前向传播进行计算,得到损失,并计算出梯度
- CPU接到GPU的梯度,取平均值,然后进行梯度更新
- 存在的问题
- 每一个设备的计算速度不一样,有的快有的慢,那么CPU在更新变量的时候,就应该等待每一个设备的一个batch进行完成,这样相当于降低了分布式的效率,解决办法请参考下面的异步更新
- 同步更新&异步更新
- 同步更新(随机梯度下降法(Sync-SGD))
- 同步随即梯度下降法的含义是在进行训练时,每个节点的工作任务需要读入共享参数,执行并行的梯度计算,同步需要等待所有工作节点把局部的梯度算好,然后将所有共享参数进行合并、累加,再一次性更新到模型的参数;下一个批次中,所有工作节点拿到模型更新后的参数再进行训练。这种方案的优势是,每个训练批次都考虑了所有工作节点的训练情况,损失下降比较稳定;劣势是,性能瓶颈在于最慢的工作结点上。
- 异步更新(异步随机梯度下降法(Async-SGD))
- 异步随机梯度下降法的含义是每个工作结点上的任务独立计算局部梯度,并异步更新到模型的参数中,不需要执行协调和等待操作。这种方案的优势是,性能不存在瓶颈;劣势是,每个工作节点计算的梯度值发送回参数服务器会有参数更新的冲突,一定程度上会影响算法的收敛速度,在损失下降的过程中抖动较大。
分布式接口
- 创建分布式集群的方法
- 创建集群的方法是为每一个任务启动一个服务,这些任务可以分布在不同的机器上,也可以同一台机器上启动多个任务,使用不同的GPU等来运行。
- 创建分布式集群的步骤
- 1、创建一个tf.train.ClusterSpec,用于对集群中的所有任务进行描述,该描述内容对所有任务应该是相同的
- 2、创建tf.train.Server,用于创建一个任务
- 3、启动
- TensorFlow分布式API接口使用
- 1. tf.train.ClusterSpec( ) 创建ClusterSpec,表示参与分布式TensorFlow计算的一组进程
- 2. tf.train.Server( ) 创建Tensorflow的集群描述信息,其中ps和worker为作业名称,通过指定ip地址加端口创建,
- 创建server
- server = tf.train.Server(server_or_cluster_def, job_name=None, task_index=None, protocol=None, config=None, start=True)
- server_or_cluster_def: 集群描述
- job_name: 任务类型名称
- task_index: 任务数
- 使用server
- server.target
- 返回tf.Session连接到此服务器的目标
- server.join()
- 参数服务器端等待接受参数任务,直到服务器关闭
- 3. tf.device( ) 指定worker运行设备
- tf.device(device_name_or_function) 指定代码运行在CPU或者GPU上
- device_name:
- 例: /job:worker/task:0/cpu:0
- function:
- tf.train.reploca_device_setter(worker_device=worker_device, cluster=cluster)
- worker_device: 例: /job:worker/task:0/cpu:0
- cluster: 集群描述对象
- 示例:
- 4. sess = tf.train.MonitoredTrainingSession( ) 创建分布式会话
- 创建会话方法
- tf.train.MonitoredTrainingSession(master, is_chief=True, checkpoint_dir=None, hooks=None, save_checkpoint_secs=600, save_summaries_steps=USE_DEFAULT, save_summaries_secs=USE_DEFAULT, config-None)
- master: 指定运行会话协议IP和端口(用于分布式)
- 例: grpc://192.168.0.1:2000
- is_chief: 是否为主worker 主worker负责初始化和恢复基础的TensorFlow会话
- checkpoint_dir: 检查点文件记录, 同时也是events目录
- config: 会话运行的配置项, tf.ConfigProto(log_device_placement=True)
- hooks: 可选SessionRunHook 对象列表
- 会话的方法
- sess.should_stop( ) : 当程序发生异常的时候should_stop 返回True
- sess.run() 跟session一样可以运行op
- 分布式会话-钩子对象
- 钩子对象作用
- 当在开启分布式会话的时候, 钩子对象方法被调用
- 指定了钩子对象(包含了begin,before_run, after_run方法), 可以实现分别在初始化会话, 调用run方法之前, 调用run方法之后, 分别运行钩子对象对应的方法
- 该功能相当于django的中间件功能
- 可以用来实现计步, 打印中间输出等功能
- 创建钩子对象的方法
- 1. 创建类并继承tf.train.SessionRunHook
- 2. 实现方法 begin 会话之前调用, 只会调用一次
- 3. 实现方法 before_run(run_context) 接收run_context 参数
- run_context: 一个SessionRunContenxt对象, 包含会话的运行信息
- return: 一个SessionRunArgs对象, 例如: tf.train.SessionRunArgs(Tensow),接收的参数必须为Tensow类型
- 4. 实现方法 after_run(run_context, run_value)
- run_context: 一个SessionRunContext 对象
- run_values: 一个SessionRunValues对象, run_values.results 为before_run返回的Tensow对象的eval属性, 即tensow的值
- 常用的钩子对象
- tf.train.StopAtStepHook(last_stop) 计步钩子对象
- last_stop: 指定训练步数, 当达到的时候抛出异常
- 注意: 在使用计步钩子的时候需要定义全局步数:global_step = tf.contrib.framework.get_or_create_global_step()