TensorFlow中的通信机制——Rendezvous(二)gRPC传输

背景

[作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor]

本篇是TensorFlow通信机制系列的第二篇文章,主要梳理使用gRPC网络传输部分模块的结构和源码。如果读者对TensorFlow中Rendezvous部分的基本结构和原理还不是非常了解,那么建议先从这篇文章开始阅读。TensorFlow在最初被开源时还只是个单机的异构训练框架,在迭代到0.8版本开始正式支持多机分布式训练。与其他分布式训练框架不同,Google选用了开源项目gRPC作为TensorFlow的跨机通信协议作为支持。gRPC的编程和使用其实是相对复杂的,TensorFlow为了能让gRPC的调用更加平滑,在调用链封装和抽象上面做了较多工作,甚至有些工作例如创建和管理gRPC channel涉及到了GrpcSession模块。从个人角度来看,利用gRPC进行Tensor通信的过程已经足够丰富,所以我们只针对gRPC传输Tensor过程进行梳理,至于涉及到gRPC管理方面的内容会在另一篇介绍分布式Session创建和管理的文章中集中梳理。

跨进程通信过程

根据之前写博客的经验,直接介绍类图结构和源码部分可能会让人懵圈,还是先从逻辑上把通信过程梳理清楚更能做到深入浅出。其实对于不是非常了解分布式系统或大规模并发系统的读者而言,TensorFlow中通信过程是有些“别扭”的。那么有的读者可能会觉得诧异,跨进程通信过程不就是一方做Send,另一方做Recv吗?这是一个理所当然的过程,为什么会“别扭”呢?是的,整个过程依然是一方做Send,另一方做Recv。而它的“别扭”之处就在于——真正的通信过程由Recv方触发,而不是Send方!这就是理解TensorFlow中使用gRPC传输Tensor过程的最关键点。

前一篇文章分析过在本地传输的场景下Tensor通信的大体过程,从机制和逻辑上来说,跨进程传输过程和本地传输没有很大的差异:TensorFlow使用Rendezvous通信Tensor,借助一个类似Table的数据结构作为传输的中转,并且Send方和Recv方依靠ParsedKey这一唯一传输标识符,跨进程通信也是如此。如果读者对这部分内容不了解,可以参考这篇文章。

Send方——将Ready的Tensor挂入本地Table

和本地传输场景下的Send过程相同,本地Tensor处于Ready状态后就被放挂了本地Worker的Table中,至此Send过程就全部完成了。所以Send过程完全没有涉及到任何跨网络传输的内容,并且Send过程是非阻塞的。

Recv方——向Send方主动发出请求,触发通信过程

Recv方是Tensor的接收方,它的处理过程是:将所需要的Tensor对应的ParsedKey拼出后,主动向Send方主动发出Request,Send方在接收到Request后立即在本地Table中查找方所需要的Tensor,找到后将Tensor封装成Response发送回Recv方。在这个过程中,Recv方可以认为是Client,Send方可以认为是Server,通过发送Request和Response来完成Tensor的传输。

结构设计解析

建议读者在阅读本节时适当翻开TensorFlow C++部分源码,但只需要理解结构关系即可(比如类之间的继承、组合、依赖关系),暂时不要阅读类的实现内容。因为RemoteRendezvous部分涉及到的类结构非常多,直接陷入细节的阅读会深陷其中不能自拔,甚至弄得一头雾水十分疲惫。在梳理结构时一边参照下文中的类图结构,一边从设计模式和架构的角度尝试去理解每个模块的司职是理解本篇细节的关键。先理解宏观结构看懂架子,再去深入理解实现细节尝试去优化是读任何代码的正确顺序。

任何场景下,通信过程几乎都是可以通过简单的图将功能描述清楚的。但是不可否认的是,任何涉及到分布式通信的系统在架构上都会对通信层做相对复杂的封装。一方面是因为通信虽然功能简单,但其实现本身具有相对较高的复杂性(大家可以尝试阅读gRPC源码感受下底层软件的复杂度)。另一方面,应用层也需要与通信底层通过抽象尽量实现较好的解耦,这样也方便将应用层模块被其他团队扩展编写。下面我们一起来探究TensorFlow中涉及到跨进程通信的Rendezvous系列。

两层抽象继承关系——RemoteRendezvous与BaseRemoteRendezvous

前一篇在介绍本地传输时我们熟悉了Rendezvous模块中与本地传输相关的类,例如LocalRendezvousImpl,IntraProcessRendezvous和SimpleRendezvous。对应地,跨进程传输也有不同的Rendezvous,从根源上来说,它们也继承于Rendezvous接口,并且不同的传输协议也有各自的Rendezvous。在这里,我们再次将前文中展示的总体类结构图展示出来,这次我们将涉及到远程传输的类用特殊颜色标出,如下图所示。

综合来看,从Rendezvous的继承结构来看,涉及到跨进程传输的Rendezvous有层:

1. RemoteRendezvous:只增加了一个Initialize方法,并标记为纯虚函数。这是因为跨进程Rendezvous需要借助Session做一些初始化工作,所以TensorFlow中所有涉及到跨进程通信的Rendezvous都需要重写Initialize函数,使用前也必须强制调用该函数。

2. 各种具体协议Rendezvous的基类——BaseRemoteRendezvous:既然所有涉及跨进程通信的Rendezvous都需要提供各自协议下实现的Initialize函数,那么没有比在RemoteRendezvous和真正特化的Rendezvous之间再添加一层继承关系更合适的做法了。事实上TensorFlow在此处也是这么设计的,这个承上启下的类就是BaseRemoteRendezvous。它还提供了公共的Send和Recv方法,这可以让继承它的特化Rendezvous尽最大可能做到代码复用。

BaseRecvTensorCall是通信的实体抽象,后面分析时会有更深的体会,在这里先有个印象即可。

开始特化——各种各样的RemoteRendezvous

TensorFlow目标是通用可扩展,所以被设计成允许底层支持多种通信协议的结构。事实上到目前为止,算上contrib目录的内容(contrib目录是广大TensorFlow贡献者添加的内容),TensorFlow已经支持包括gRPC,RDMA(Remote Direct Memroy Access),GDR(GPU Dirrect)和MPI四种通信协议,因此包含了四种对应的Rendezvous,他们分别是RpcRemoteRendezvous,RDMARemoteRendezvous,GdrRemoteRendezvous和MPIRemoteRendezvous。每种通信协议各有其特点,有时候其可用性也取决于硬件和软件条件(比如RDMA需要支持RDMA协议的网卡,通常跑在Infiniband和RoCE网络上,如果没有硬件支持,那么RDMA将无法使用,GDR也是这个道理)。从代码中可以看出,实现每种具体的RemoteRendezvous都有一定的复杂性,所以很难想象在没有封装抽象和代码复用的结构里如何实现这些内容。在本篇我们关注RpcRemoteRendezvous,它是gRPC协议实现的RemoteRendezvous。

令人熟悉的管理器模式——RendezvousMgr

为了更好地管理RemoteRendezvous,TensorFlow设计了相应的管理器——RendezvousMgr相关类,并为每种具体的RemoteRendevzous做了特化。熟悉设计模式的读者都知道,管理器是一种经典的设计模式,它能使管理职责的变化独立于类本身。RendezvousMgr主要负责RemoteRendezvous的创建和销毁,它也定义了两个本地版本的Recv接口。有的读者可能会问,管理器为什么还允许做Recv?并且只能做本地的Recv?我个人判断添加这两个接口纯粹是为了方便某些地方的使用。至于RendezvousMgr的创建时机和RemoteRendezvous的初始化过程并不是本篇解析的范畴,因为这涉及到分布式场景下创建Server的较长链路,这部分内容会在以后的博客中详细解析。下面是RendezvousMgr相关的类图结构,我们可以看到其接口类中已经定义了Recv接口。

RpcRemoteRendezvous通信过程与源码解析

上一小节中对RemoteRendezvous相关类结构和类间的关系做了解析,旨在从架构层面帮助读者理解各个类的职能。虽然涉及到的内容比较多,但是整体的结构和逻辑还是非常清晰的。如果读者尝试通过阅读源码辅助理解上述内容之后仍然感觉有些眼花缭乱,没有关系,我们在这里暂时做一个简单地梳理,将重点内容梳理到以下几条。

1.  本地Rendezvous和RemoteRendezvous共同继承了同一个接口;

2. RemoteRendezvous需要支持不同的通信协议,因此派生了各种各样的实现类;

3. RemoteRendezvous的使用较为复杂,为此引入了管理器模式——RendezvousMgr,它负责RemoteRendezvous的创建和销毁,并添加了两个额外的Recv接口方便某些场景直接调用;

4. RemoteRendezvous做了两层继承结构只是为了添加一个Initialize方法。

本篇我们梳理使用gRPC协议的部分,从上文中梳理的结构中不难看出,这部分涉及到的类并不多。

1. Rendezvous相关类——RemoteRendezvous,BaseRemoteRendezvous,RpcRemoteRendezvous;

2. 管理器——BaseRendezvousMgr,RpcRendezvousMgr

3. 其他类——BaseRecvTensorCall,RpcRecvTensorCall和DefferedCall

毕竟是涉及到了gRPC协议本身的使用,所以有必要在梳理源码之前从宏观上对gRPC的工作流程做一个简单地梳理。

gRPC编程中的代理模式——Stub与Service

在此我们假设同学们对gRPC的原理和使用有一些基本的了解,比如需要使用Protobuf预先定义Service接口,并且区分Stub和Service等。对此不了解的同学还是建议先认真阅读一下gRPC的使用文档和范例,下面这段文字只对gRPC做一个非常简单的描述。

在一次RPC调用中,客户端需要调用服务端的服务,然后将处理结果返回给客户端。而gRPC做到了“让客户端调用远端函数时就像调用本地函数一样”的体验,这得益于一种经典的设计模式——代理模式。负责为客户端代理的节点(gRPC中称之为Stub)会将请求和参数传到服务端,并由Service进行实际的处理,然后将结果返回给Stub,最终返回到客户端中。我们甚至可以认为负责代理的Stub就是客户端,因为它的职责就是与远端交互并取得结果。另外,为了能够让传输量尽可能少,也为了能够让传输不受客户端和服务端具体的类型限制,gRPC在做跨网络传输前将消息统一序列化成Protobuf格式。下图是从gRPC官网教程中摘出的工作原理图。

Send过程

因为Send过程并不涉及跨进程传输,只是将Ready的Tensor挂入本地Table之中,所以它和LocalRendezvousImpl的Send完全相同。不仅如此,TensorFlow中的任何RemoteRendezvous的Send过程都要遵循这样的原理,基于代码复用的考虑,将这部分内容都被抽象到了公共基类BaseRemoteRendezvous的Send函数里是一个很好的设计。事实上,BaseRemoteRendezvous的Send过程就是调用了LocalRendezvousImpl的Send过程,所以LocalRendezvousImpl必须要作为BaseRemoteRendezvous的成员之一。下面的代码展示了这一过程。

 1 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
 2                                   const Rendezvous::Args& args,
 3                                   const Tensor& val, const bool is_dead) {
 4   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
 5   {
 6     mutex_lock l(mu_);
 7     if (!status_.ok()) return status_;
 8     DCHECK(is_initialized_locked());
 9     if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
10       return errors::InvalidArgument(
11           "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
12           session_->worker_name);
13     }
14   }
15   // Buffers "val" and "device_context" in local_.
16   return local_->Send(parsed, args, val, is_dead);
17 }

Recv过程

Recv过程就非常复杂了,因为每种RemoteRendezvous都涉及到不同的通信协议以及管理方式,所以Recv函数是真正需要继承重写的模块。在看RpcRemoteRendezvous具体的实现之前,我们必须先将gRPC定义服务的接口部分梳理清楚。

gRPC的服务定义接口文件

在TensorFlow的core/protobuf文件中,我们需要研究一下worker_service.proto文件,这个文件中定义了若干RPC Service接口。

虽然它定义了很多RPC服务接口,但是我们只需要关注和Tensor接收相关的接口定义即可。准确地说,目前我们必须要知道的是下面这个Service定义。

  // See worker.proto for details.
  rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
    // RecvTensor Method
  }

显然,这是一个让服务端处理“接收Tensor”的服务(注意是让服务端处理名为“接收Tensor”的服务,而不是让服务端去接收Tensor。因为客户端有接收Tensor的需求,但需要服务端发送Tensor,为客户端发送Tensor的服务被称之为“接收Tensor”),按照注释提示,我们可以在worker.proto中找到RecvTensorRequest和RecvTensorResponse的数据结构,这部分结构读者可以自己查阅,非常容易理解。在编译时,扩展的Protobuf编译器会对worker_service.proto中的rpc接口生成C++服务接口代码和Stub代码(毕竟Stub代码比较纯粹并且和业务逻辑无关,它只是一个向对应Service端发送处理请求的过程),TensorFlow只需要对具体的Service提供实现即可。

与gRPC生成的代码联系起来

gRPC会为worker_service.proto中每一个rpc服务生成C++接口代码,为了区分多个rpc服务,特意为每个服务生成了特殊的名字。比如RecvTensor服务的名字就是/tensorflow.WorkerService/RecvTensor。为了不直接使用冗长的字符串,TensorFlow为worker_service.proto中的每个服务都做了enumeration的映射,这部分代码在tensorflow/core/distributed_runtime/grpc_worker_service_impl.h和同名实现文件中。

 1 // Names of worker methods.
 2 enum class GrpcWorkerMethod {
 3   kGetStatus,
 4   kCreateWorkerSession,
 5   kDeleteWorkerSession,
 6   kRegisterGraph,
 7   kDeregisterGraph,
 8   kRunGraph,
 9   kCleanupGraph,
10   kCleanupAll,
11   kRecvTensor,
12   kRecvBuf,
13   kLogging,
14   kTracing,
15   kCompleteGroup,
16   kCompleteInstance,
17   kGetStepSequence,
18 };

下面是从enumeration类型映射到具体字符串的函数。

 1 const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
 2   switch (id) {
 3     case GrpcWorkerMethod::kGetStatus:
 4       return "/tensorflow.WorkerService/GetStatus";
 5     case GrpcWorkerMethod::kCreateWorkerSession:
 6       return "/tensorflow.WorkerService/CreateWorkerSession";
 7     case GrpcWorkerMethod::kDeleteWorkerSession:
 8       return "/tensorflow.WorkerService/DeleteWorkerSession";
 9     case GrpcWorkerMethod::kRegisterGraph:
10       return "/tensorflow.WorkerService/RegisterGraph";
11     case GrpcWorkerMethod::kDeregisterGraph:
12       return "/tensorflow.WorkerService/DeregisterGraph";
13     case GrpcWorkerMethod::kRunGraph:
14       return "/tensorflow.WorkerService/RunGraph";
15     case GrpcWorkerMethod::kCleanupGraph:
16       return "/tensorflow.WorkerService/CleanupGraph";
17     case GrpcWorkerMethod::kCleanupAll:
18       return "/tensorflow.WorkerService/CleanupAll";
19     case GrpcWorkerMethod::kRecvTensor:
20       return "/tensorflow.WorkerService/RecvTensor";
21     case GrpcWorkerMethod::kRecvBuf:
22       return "/tensorflow.WorkerService/RecvBuf";
23     case GrpcWorkerMethod::kLogging:
24       return "/tensorflow.WorkerService/Logging";
25     case GrpcWorkerMethod::kTracing:
26       return "/tensorflow.WorkerService/Tracing";
27     case GrpcWorkerMethod::kCompleteGroup:
28       return "/tensorflow.WorkerService/CompleteGroup";
29     case GrpcWorkerMethod::kCompleteInstance:
30       return "/tensorflow.WorkerService/CompleteInstance";
31     case GrpcWorkerMethod::kGetStepSequence:
32       return "/tensorflow.WorkerService/GetStepSequence";
33   }
34   // Shouldn‘t be reached.
35   LOG(FATAL) << "Invalid id: this line shouldn‘t be reached.";
36   return "invalid id";
37 }

另外,还需要为每个RPC服务注册为异步服务,这需要使用gRPC自带的AddMethod接口和MarkMethodAsync接口,如下所示。

1 WorkerService::AsyncService::AsyncService() {
2   for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
3     AddMethod(new ::grpc::internal::RpcServiceMethod(
4         GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
5         ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
6     ::grpc::Service::MarkMethodAsync(i);
7   }
8 }

好了,接下来就是解析源码中具体的交互过程了。其实TensorFlow在框架层面对gRPC的使用了一些Best Practice,比如异步处理请求的架构和多线程轮询Completion Queue等。将这些连在一起梳理需要更多的篇幅,一次性展示大量的内容也不利于阅读,所以我们只对发送和接收过程做一个梳理。

Client端的调用链

从BaseRemoteRendeezvous的RecvAsync出发,逐渐深入调用链底层。时序图是分析调用链的最好工具,下面给出了Client端到Stub的调用过程,这里面涉及到了几个新的类。

1. RpcRecvTensorCall:这是一次gRPC调用的抽象,继承了BaseRecvTensorCall这个抽象基类,它封装了复杂的后续调用链。

2. GrpcRemoteWorker:它也是client端的内容,只不过它是Remote端的代理。

3. RpcState:这是真正封装了一次RPC调用及状态的类,它会直接对Stub以及GenericClientAsyncResponseReader进行管理,比如向服务端发送异步请求并等待结果等。

Client端是一个虚拟角色,它可以是调用RpcRemoteRendezvous的任何一个模块。我们可以看到,RpcRemoteRendezvous的一次RecvRemoteAsync过程非常长,并且Stub的调用时异步的。这里的代码确实有些多,所以我们只展示一下关键代码段,但是建议读者打开源码仔细阅读每个调用链。

下面是RecvRemoteAsync的代码段,主要做了RpcRecvTensorCall的初始化,注册以及启动工作。

 1 void RpcRemoteRendezvous::RecvFromRemoteAsync(
 2     const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
 3     DoneCallback done) {
 4   CHECK(is_initialized());
 5   Status s;
 6
 7   // Prepare a RecvTensor call that can handle being aborted.
 8   RpcRecvTensorCall* call = get_call_freelist()->New();
 9
10   // key.src_device identifies a remote device.
11   if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
12                                         &call->src_rel_device_)) {
13     s = errors::Internal(parsed.src_device,
14                          " is invalid remote source device.");
15   }
16   WorkerSession* sess = session();
17   WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
18   if (s.ok() && rwi == nullptr) {
19     s = errors::Internal("No worker known as ", call->src_worker_);
20   }
21
22   Device* dst_device;
23   if (s.ok()) {
24     s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
25   }
26   if (!s.ok()) {
27     if (rwi != nullptr) {
28       sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
29     }
30     get_call_freelist()->Release(call, sess->worker_cache.get());
31     done(s, Args(), recv_args, Tensor{}, false);
32     return;
33   }
34
35   call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
36              recv_args, std::move(done));
37
38   // Record "call" in active_ so that it can be aborted cleanly.
39   RegisterCall(call);
40
41   // RendezvousMgr already aborted, shouldn‘t send RPC call any more
42   if (!call->status().ok()) {
43     call->done()(call->status(), Args(), Args(), Tensor(), false);
44     session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
45     call->wi_ = nullptr;
46     get_call_freelist()->Release(call, session()->worker_cache.get());
47     return;
48   }
49
50   // Start "call".
51   Ref();
52   call->Start([this, call]() {
53     // Removes "call" from active_. Prevent StartAbort().
54     DeregisterCall(call);
55     // If StartAbort was called prior to DeregisterCall, then the
56     // current status should be bad.
57     Status s = call->status();
58     call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
59     session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
60     call->wi_ = nullptr;
61     get_call_freelist()->Release(call, session()->worker_cache.get());
62     Unref();
63   });
64 }

下面是GrpcRemoteWorker调用RPCState的过程,最后的IssueRequest即开始创建RPCState并触发stub的调用。

void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
                       TensorResponse* response, StatusCallback done) override {
    VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
    int64 start_usec = Env::Default()->NowMicros();
    // Type-specialized logging for this method.
    bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
    StatusCallback wrapper_done;
    const StatusCallback* cb_to_use;
    if (!logging_active) {
      cb_to_use = &done;  // No additional work to do, so just use done directly
    } else {
      wrapper_done = [this, request, response, done, start_usec](Status s) {
        if (logger_->LoggingActive()) {
          int64 end_usec = Env::Default()->NowMicros();
          int64 step_id = request->step_id();
          int64 bytes = response->tensor().TotalBytes();
          int64 send_start_usec = start_usec;
          // If a send start time was reported by the other side, use
          // that instead.  Maybe we should mark the display if we‘re using
          // our local time instead of the remote start time?
          if (response->metadata().send_start_micros()) {
            // send_start_micros is the timestamp taken when the
            // remote machine began to send the RecvTensor response.
            // Due to clock skew between source and dest machines, it
            // is possible that send_start_micros can be larger than
            // end_usec or less than start_usec.
            //
            // To respect causality, we enforce the invariants that
            // the RecvTensor response can not have been sent before
            // the RecvTensor request, and must have been sent before
            // it was received.
            send_start_usec = std::max(
                start_usec,
                static_cast<int64>(response->metadata().send_start_micros()));
            send_start_usec = std::min(send_start_usec, end_usec - 1);
          }
          const string& key = request->rendezvous_key();
          std::vector<string> key_parts = str_util::Split(key, ‘;‘);
          if (key_parts.size() != 5) {
            LOG(WARNING) << "Bad key: " << key;
          } else {
            logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
                                      key_parts[3],  // tensor name
                                      key_parts[0],  // src_device
                                      key_parts[2],  // dst_device
                                      bytes);
          }
        }
        VLOG(2) << "done callback, req: " << request->DebugString()
                << " response " << response->metadata().DebugString();
        done(s);
      };
      cb_to_use = &wrapper_done;
    }

    IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
  }

最后展示一下Stub的触发位置,这个函数在RPCState类中,并且在创建RPCState对象时立即被调用。

 1 void StartCall() {
 2     context_.reset(new ::grpc::ClientContext());
 3     context_->set_fail_fast(fail_fast_);
 4
 5     if (timeout_in_ms_ > 0) {
 6       context_->set_deadline(
 7           gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
 8     }
 9     if (call_opts_) {
10       call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
11     }
12
13     VLOG(2) << "Starting call: " << method_;
14
15     call_ = std::move(
16         stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_));
17     call_->StartCall();
18     call_->Finish(&response_buf_, &status_, this);
19   }

Server端负责查找Tensor的Service

如果我们把异步处理请求的架构和多线程轮询Completion Queue的Best Practice去除,那么Service端其实并不复杂,调用链相对Client端短了很多,下面的时序图展示了自Server端接收请求后的调用过程,这里面也涉及到了几个新的类。

1. GrpcWorkerServiceThread:这是服务端处理请求的线程类。

2. GrpcWorker:这是真正负责处理请求的Worker,是GrpcRemoteWorker的服务端版本;

3. WorkerCall:这是服务端处理一次gRPC请求和响应的类,抽象为WorkerCall,其实这也是个别名,真实的名称较长;

4. ServerAsyncResponseWriter:这是gRPC为用户端提供的Response writer,是承载响应的实体。

5. Utils:这其实不是一个类,而是多个工具的组合,为了在时序图表达方便,统称为Utils。

可以看出,服务端接收到请求后,会调用RecvLocalAsync在本地将客户端所需要的Tensor查找出来,然后拷贝到CPU上,最后利用gRPC发送回客户端。同样,我们展示关键代码段。

下面是GrpcWorker调用RendezvousMgr的RecvLocalAsync为客户端寻找真正Tensor的过程。回调函数中能够看出,在找到对应Tensor后,需要将Tensor做Encode,然后拷贝到CPU端。

 1  env_->rendezvous_mgr->RecvLocalAsync(
 2       step_id, parsed,
 3       [opts, response, done, src_dev, request](
 4           const Status& status, const Rendezvous::Args& send_args,
 5           const Rendezvous::Args& recv_args, const Tensor& val,
 6           const bool is_dead) {
 7         opts->ClearCancelCallback();
 8         if (status.ok()) {
 9           // DMA can only be used for Tensors that do not fall into
10           // the following three odd edge cases: 1) a zero-size
11           // buffer, 2) a dead tensor which has an uninit value, and
12           // 3) the tensor has the on_host allocation attribute,
13           // i.e. it‘s in CPU RAM *independent of its assigned
14           // device type*.
15           const bool on_host = send_args.alloc_attrs.on_host();
16           {
17             // Non-DMA cases.
18             if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
19               DeviceContext* send_dev_context = send_args.device_context;
20               AllocatorAttributes alloc_attrs;
21               alloc_attrs.set_gpu_compatible(true);
22               alloc_attrs.set_on_host(true);
23               Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
24               Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
25               CHECK(send_dev_context)
26                   << "send dev name: " << src_dev->name()
27                   << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
28               // "val" is on an accelerator device. Uses the device_context to
29               // fill the copy on host.
30               StatusCallback copy_ready = [response, done, copy,
31                                            is_dead](const Status& s) {
32                 // The value is now ready to be returned on the wire.
33                 grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
34                 done(s);
35                 delete copy;
36               };
37
38               send_dev_context->CopyDeviceTensorToCPU(
39                   &val, request->rendezvous_key(), src_dev, copy, copy_ready);
40             } else {
41               grpc::EncodeTensorToByteBuffer(is_dead, val, response);
42               done(Status::OK());
43             }
44           }
45         } else {
46           //  !s.ok()
47           done(status);
48         }
49       });

至此,我们的Rendezvous之gRPC传输之旅就圆满结束了,在阅读本篇时还是希望读者能够在理解结构设计后,对照C++源码仔细阅读反复推敲里面的每一个细节,这样才能有更深的理解。

一个需要思考的问题——gRPC传输Tensor很低效?

是的,确实很低效。为什么?从设计哲学上说,gRPC本身设计并不适合深度学习训练场景。从细节上来说它有以下几个缺陷:

1. gRPC发送Tensor前,接收Tensor后必须要做序列化,在Tensor很大的时候这是一个非常讨厌的overhead,发送接收延迟过大;

2. 序列化根本没有对数据做任何压缩,这是因为Tensor都是稠密的,所以序列化没有意义;

3. 不能支持RDMA和GPU Direct。虽然这依赖于硬件,但是gRPC在软件层面也并没有做这些适配。

所以大部分人使用TensorFlow分布式时都会对性能有很大的抱怨,这里面很大的原因和gRPC有关。如果你使用NCCL或者MPI,那么你会得到不一样的性能。

总结

本篇文章篇幅较长,是Rendezvous机制系列的第二篇,主要梳理了涉及到gRPC传输的模块架构设计和源码细节,并且详细梳理了通信过程。理解TensorFlow跨机传输的关键在于理解一个事实:真正的通信过程由Recv方触发,而不是Send方!Send依然将Ready的Tensor挂入本地Table中,而Recv会向Send端发送gRPC请求查询所需要的Tensor,然后返回所需要的结果,这个过程虽然有些别扭,但逻辑上并不稀奇。从结构设计上来说,RemoteRendezvous沿用了Rendezvous接口,并且完全复用了LocalRendezvousImpl的Send代码,而Recv由于涉及到具体的通信细节和管理机制,则各有各的不同。另外,RemoteRendezvous相对LocalRendezvous复杂很多,需要管理器进行管理。最后一大部分是Send和Recv的源码细节展示,因为无论是客户端还是服务端,其调用链都比较长,所以以时序图的形式展示各个类之间的调用关系和协作关系较为清晰,具体每个调用的细节建议读者结合源码逐一分析,并连同本篇文章一起理解较为深刻。最后,我们总结了gRPC传输Tensor的明显缺陷,当然这也是为性能优化开辟了新的空间。

原文地址:https://www.cnblogs.com/deep-learning-stacks/p/10355770.html

时间: 2024-10-11 17:12:52

TensorFlow中的通信机制——Rendezvous(二)gRPC传输的相关文章

TensorFlow中的通信机制——Rendezvous(一)本地传输

背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 在TensorFlow源码中我们经常能看到一个奇怪的词--Rendezvous.如果从仔细统计该单词出现的频率和模块,你会发现无论在单机还是分布式,无论在core目录还是contrib目录都存在它的身影,所涉及的模块非常多.Rendezvous是一个法语单词,发音也比较特殊,一般直译为"约会.相会.会和",而在TensorFlow中,Rendezvous是用来完成消

.Net中Remoting通信机制简单实例

.Net中Remoting通信机制 前言: 本程序例子实现一个简单的Remoting通信案例 本程序采用语言:c# 编译工具:vs2013工程文件 编译环境:.net 4.0 程序模块: Test测试 Talker Server端 Client端 源代码工程文件下载 Test测试程序截图: Talker类: 1 public class Talker : MarshalByRefObject 2 { 3 public void Talk(string word) 4 { 5 System.Con

.Net中Remoting通信机制

Remoting通信机制 Remoting介绍 主要元素 通道类型 激活方式 对象定义 Remoting介绍 什么是Remoting,简而言之,我们可以将其看作是一种分布式处理方式. 从微软的产品角度来看,可以说Remoting就是DCOM(分布式组件对象模型,分布式组件对象模式)的一种升级,它改善了很多功能,并极好的融合到.Net平台下.Microsoft .NET Remoting 提供了一种允许对象通过应用程序域与另一对象进行交互的框架.这也正是我们使用Remoting的原因.为什么呢?在

Android中的常见通信机制和Linux中的通信机制

Handler Handler是Android系统中的一种消息传递机制,起作用是应对多线程场景.将A进程的消息传递给B线程,实现异步消息处理.很多情况是将工作线程中需要更新UI的操作消息传递给UI主线程,而实现更新UI操作. 因为工作线程和主线程是共享地址空间,即Handler实例对象mHandler位于线程间共享的内存堆上,工作线程和主线程直接使用该对象,只需要注意多线程的同步问题.工作系统通过mHandler向其成员变量MessageQueue中添加Message,而主线程一直处于loop中

浅谈Linux中的信号机制(二)

首先谢谢 @小尧弟 这位朋友对我昨天夜里写的一篇<浅谈Linux中的信号机制(一)>的指正,之前的题目我用的“浅析”一词,给人一种要剖析内核的感觉.本人自知功力不够,尚且不能对着Linux内核源码评头论足.以后的路还很长,我还是一步一个脚印的慢慢走着吧,Linux内核这座山,我才刚刚抵达山脚下. 好了,言归正传,我接着昨天写下去.如有错误还请各位看官指正,先此谢过. 上篇末尾,我们看到了这样的现象:send进程总共发送了500次SIGINT信号给rcv进程,但是实际过程中rcv只接受/处理了1

Android中AIDL通信机制分析

一.背景 ·1.AIDL出现的原因 在android系统中,每一个程序都是运行在自己的进程中,进程之间无法进行通讯,为了在Android平台,一个进程通常不能访问另一个进程的内存空间,所以要想对话,需要将对象分解成操作系统可以理解的基本单元,并且有序的通过进程边界.通过代码来实现这个数据传输过程是冗长乏味的,Android提供了AIDL工具来处理这项工作,实现IPC(进行间的通信)与J2e中的RMI类似. ·2.绑定service 我看了很多人都博客都没有说到这里,其实我个人感觉AIDL就是一个

android ipc通信机制之二序列化接口和Binder

IPC的一些基本概念,Serializable接口,Parcelable接口,已经Binder.此核心为最后的IBookManager.java类!!! Serializable接口,Parcelable接口都是可以完成对象的序列化过程. 序列化 (Serialization)将对象的状态信息转换为可以存储或传输的形式的过程.在序列化期间,对象将其当前状态写入到临时或持久性存储区.以后,可以通过从存储区中读取或反序列化对象的状态,重新创建该对象. 两者均可以实现序列化并且都可以用于Intent数

通过了解Servlet和Http之间的关系,了解web中http通信使用(二)

注:图片如果损坏,点击文章链接:https://www.toutiao.com/i6512399401825075719/ 上一节,简单理解"请求服务"的内容:http协议中的请求,接下来我们再看下http协议中的响应 http协议中的响应 Http响应和Http请求一样,也是有响应的格式 ? ? 细化一下: 请求 响应 实际中是什么样子呢? 我们把我们之前的代码稍微改动下,方便观察 然后我们打开浏览器,输入地址 然后按F12,出现如下界面 然后点提交 里面的内容基本上就是http协议

Android中对消息机制(Handler)的再次解读

今天遇到一些关于在子线程中操作Handler的问题,感觉又要研究源代码了,但是关于Handler的话,我之前研究过,可以参考这篇文章:http://blog.csdn.net/jiangwei0910410003/article/details/17021809.但是这篇文章没有说的那么深入了,所以这次就更深入的解读一下. 摘要 Android中的应用程序都是通过消息驱动的,系统为每一个应用程序维护一个消息队列(MessageQueue),应用程序的主线程不断的从这个消息队列中获取消息(Loop