Before detailing the various distribution strategies in TensorFlow, we first need to look at the foundation they are built on: the distributed runtime environment. Laying this groundwork now will clear away obstacles in the later analysis and make that work much easier. This article introduces the static architecture of the Worker (a set of related concepts).

The other articles in this series are:

Heterogeneous Distributed Learning based on the TensorFlow distributed paper [translation]

Implementation of Control Flow in TensorFlow

TensorFlow Distributed environment (1) — overall architecture

TensorFlow distributed environment (2)– Master static logic

1. Inheritance

1.1 Role Concepts

The TensorFlow Worker class is an entity that performs calculations, and its main functions are:

  • Receive requests from the Master.
  • Manage WorkerSessions.
  • Process registered subgraphs, for example splitting a subgraph further according to the devices available on its own node.
  • Run the registered subgraph on each device.
  • Support worker-to-worker tensor transfers. For example, cudaMemcpyAsync is used between CPU and GPU, DMA is used between local GPUs, and gRPC or RDMA is used between remote workers.
  • After execution, fetch the result from sink, the termination node of the computation graph.

See protobuf/worker_service.proto for more details on each method.

1.2 Interface

Access to the WorkerService is done through the WorkerInterface. WorkerInterface is the worker's interface class, used to interact with the TensorFlow worker service. It mainly includes:

  • Asynchronous virtual functions, such as CreateWorkerSessionAsync, which are implemented by derived classes. These virtual functions correspond to the GrpcWorkerMethod entries supported by GrpcWorkerService, as well as to the corresponding protobuf messages.
  • Synchronous functions, such as CreateWorkerSession, which invoke the corresponding asynchronous virtual function through a call of the form CallAndWait(&ME::CreateWorkerSessionAsync, request, response). A simplified sketch of this pattern follows.
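
To make this pattern concrete, here is a minimal, self-contained sketch of the same sync-over-async idea. This is not the actual TensorFlow code (the real CallAndWait and Notification appear later in this article); the interface and method names here are hypothetical, and it only illustrates how a synchronous wrapper blocks on a notification until the asynchronous virtual call invokes its callback.

#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>

// Simplified stand-ins for TensorFlow's Status/StatusCallback.
struct Status {
  std::string msg;
  bool ok() const { return msg.empty(); }
};
using StatusCallback = std::function<void(const Status&)>;

// A tiny Notification, analogous to tensorflow::Notification.
class Notification {
 public:
  void Notify() {
    std::lock_guard<std::mutex> l(mu_);
    done_ = true;
    cv_.notify_all();
  }
  void WaitForNotification() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return done_; });
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
};

class MyWorkerInterface {
 public:
  // Asynchronous virtual function, implemented by derived classes.
  virtual void DoSomethingAsync(const std::string* request,
                                std::string* response,
                                StatusCallback done) = 0;

  // Synchronous wrapper: blocks until the async callback fires.
  Status DoSomething(const std::string* request, std::string* response) {
    Status ret;
    Notification n;
    DoSomethingAsync(request, response, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }

  virtual ~MyWorkerInterface() = default;
};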

1.3 WorkerInterface Derived class

As shown in the figure below, there are three implementations of WorkerInterface.

  • Worker: This class can be subclassed to provide specialized implementations of specific methods for different transport mechanisms. For example, GrpcWorker specifically implements the RecvTensorAsync() method to support more efficient gRPC data structures for handling large binary data.
  • GrpcWorker: Derived from Worker in turn; this is the Worker role in local mode. If the Master and the Worker are both local, they can call each other directly without RPC network transmission.
  • GrpcRemoteWorker: In distributed mode, the Worker is located at the remote end, and the local user needs to use GrpcRemoteWorker to access the remote Worker.
    • GrpcRemoteWorker is a gRPC client, which accesses GrpcWorkerService on the remote Worker through the stub.
    • GrpcWorkerService implements all interfaces defined by WorkerService, but the actual work is forwarded to the local GrpcWorker.

Specific examples are as follows:

Figure 1 Worker logical relationships

2. GrpcRemoteWorker

GrpcRemoteWorker is equivalent to a local proxy for the remote Worker.

  • The local Master partitions the computation graph and then calls the local Worker or a GrpcRemoteWorker to execute each partition's subgraph, depending on whether the partition is local or remote.
  • The local GrpcRemoteWorker is created in GetOrCreateWorker, in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc.
  • GrpcRemoteWorker sends gRPC requests to the remote end via IssueRequest.
  • After receiving the request, the remote GrpcWorkerService daemon calls the local Worker to process the request and returns the result after completion.

2.1 Definition

The GrpcRemoteWorker code is as follows; we omit some code, such as the implementation of the DeleteWorkerSessionAsync method.

class GrpcRemoteWorker : public WorkerInterface {
 public:
  explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                            ::grpc::CompletionQueue* completion_queue,
                            thread::ThreadPool* callback_threadpool,
                            WorkerCacheLogger* logger, const string& target)
      : channel_(std::move(channel)),
        stub_(channel_),
        cq_(completion_queue),
        callback_threadpool_(callback_threadpool),
        getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
        createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
        deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
        registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
        deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
        rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
        cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
        cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
        recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
        recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
        logging_(Method(GrpcWorkerMethod::kLogging)),
        tracing_(Method(GrpcWorkerMethod::kTracing)),
        completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
        instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
        getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
        markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
        logger_(logger),
        target_(target) {}

  ~GrpcRemoteWorker() override {}

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

  void RegisterGraphAsync(const RegisterGraphRequest* request,
                          RegisterGraphResponse* response,
                          StatusCallback done) override {
    IssueRequest(request, response, registergraph_, std::move(done));
  }

  void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                     RunGraphResponse* response, StatusCallback done) override {
    IssueRequest(request, response, rungraph_, std::move(done), call_opts);
  }
  void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
                     MutableRunGraphResponseWrapper* response,
                     StatusCallback done) override {
    IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
                 rungraph_, std::move(done), call_opts);
  }

 private:
  // Utility method for issuing a generic asynchronous request. The
  // given callback, done, will be called when the RPC completes.
  void IssueRequest(const protobuf::Message* request,
                    protobuf::Message* response, const ::grpc::string& method,
                    StatusCallback done, CallOptions* call_opts = nullptr,
                    bool fail_fast = true) {
    new RPCState<protobuf::Message>(
        &stub_, cq_, method, *request, response, std::move(done), call_opts,
        callback_threadpool_, MaxRetries(), fail_fast, &target_);
  }

  void IssueRequest(const protobuf::Message* request, TensorResponse* response,
                    const ::grpc::string& method, StatusCallback done,
                    CallOptions* call_opts = nullptr) {
    new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                 std::move(done), call_opts,
                                 callback_threadpool_, MaxRetries(),
                                 /*fail_fast=*/true, &target_);
  }

  // Helper function for initializing the RpcMethod objects below.
  const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }

  // Helper function for configuring max GRPC retries. Defaults to 0 (no
  // retries).
  const int64_t MaxRetries() {
    int64_t max_retries = -1;
    TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
    return max_retries;
  }

  SharedGrpcChannelPtr channel_;
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
  thread::ThreadPool* callback_threadpool_;

  const ::grpc::string getstatus_;
  const ::grpc::string createworkersession_;
  const ::grpc::string deleteworkersession_;
  const ::grpc::string registergraph_;
  const ::grpc::string deregistergraph_;
  const ::grpc::string rungraph_;
  const ::grpc::string cleanupgraph_;
  const ::grpc::string cleanupall_;
  const ::grpc::string recvtensor_;
  const ::grpc::string recvbuf_;
  const ::grpc::string logging_;
  const ::grpc::string tracing_;
  const ::grpc::string completegroup_;
  const ::grpc::string instancesource_;
  const ::grpc::string getstepsequence_;
  const ::grpc::string markrecvfinished_;

  // Support for logging.
  WorkerCacheLogger* logger_;
  const string target_;

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};

2.2 Generation

The generated code is as follows:

WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                     ::grpc::CompletionQueue* completion_queue,
                                     thread::ThreadPool* callback_threadpool,
                                     WorkerCacheLogger* logger,
                                     const string& target) {
  return new GrpcRemoteWorker(std::move(channel), completion_queue,
                              callback_threadpool, logger, target);
}

The actual call is made in the worker cache, located in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc, which decides what kind of Worker to produce based on the target.

WorkerInterface* GetOrCreateWorker(const string& target) override {
  if (target == local_target_) {
    return local_worker_;
  } else {
    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
    if (!channel) {
      return nullptr;
    }
    size_t index = AssignWorkerToThread(target);
    return NewGrpcRemoteWorker(
        channel, worker_env_->GetCompletionQueue(index),
        worker_env_->GetThreadPool(), &logger_, target);
  }
}

2.3 Sending a Request

Let's look at how a request is sent. CreateWorkerSessionAsync actually sends the createworkersession_ request.

  void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                CreateWorkerSessionResponse* response,
                                StatusCallback done) override {
    IssueRequest(request, response, createworkersession_, std::move(done));
  }

IssueRequest was shown in the definition above and is listed again below. You can see that it calls the remote method, which in our case is createworkersession_.

void IssueRequest(const protobuf::Message* request,
                  protobuf::Message* response, const ::grpc::string& method,
                  StatusCallback done, CallOptions* call_opts = nullptr,
                  bool fail_fast = true) {
  new RPCState<protobuf::Message>(
      &stub_, cq_, method, *request, response, std::move(done), call_opts,
      callback_threadpool_, MaxRetries(), fail_fast, &target_);
}

createworkersession_ is configured in the constructor.

explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                          ::grpc::CompletionQueue* completion_queue,
                          thread::ThreadPool* callback_threadpool,
                          WorkerCacheLogger* logger, const string& target)
    : channel_(std::move(channel)),
      createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),  // configured here
      ...

GrpcWorkerMethodName is defined in tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc. It returns a specific string, namely the method name on the remote GrpcWorker. You can see that what CreateWorkerSessionAsync actually calls is "/tensorflow.WorkerService/CreateWorkerSession".

// Names of worker methods.
enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }
  // Shouldn't be reached.
  LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
  return "invalid id";
}

3. Worker Service

WorkerService is a gRPC service that defines a TensorFlow service. WorkerService executes dataflow graphs on a set of local devices on behalf of MasterService. A WorkerService keeps track of multiple "registered graphs." Each registered graph is a subgraph of the client's computation graph, containing only the nodes that should be executed on this worker (plus any additional nodes needed for interprocess communication using the RecvTensor method).

The Master finds the other Server instances in the cluster based on the ClusterSpec, and uses them as Worker roles. The Master then distributes subgraphs to these Worker nodes and arranges for them to complete the computation of their specific subgraphs. If there are data dependencies between workers, they interact through inter-process communication. Whether the Master calls a Worker or workers access each other, they must follow the interface specification defined by WorkerService. All WorkerService interfaces are defined in the worker_service.proto file, and a short usage sketch follows the proto below.

service WorkerService {
  // See worker.proto for details.
  rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);

  // See worker.proto for details.
  rpc CreateWorkerSession(CreateWorkerSessionRequest)
      returns (CreateWorkerSessionResponse);

  // See worker.proto for details.
  rpc DeleteWorkerSession(DeleteWorkerSessionRequest)
      returns (DeleteWorkerSessionResponse);

  // See worker.proto for details.
  rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);

  // See worker.proto for details.
  rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);

  // See worker.proto for details.
  rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);

  // See worker.proto for details.
  rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);

  // See worker.proto for details.
  rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);

  // See worker.proto for details.
  rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
    // RecvTensor Method
  }

  // See worker.proto for details.
  rpc Logging(LoggingRequest) returns (LoggingResponse);

  // See worker.proto for details.
  rpc Tracing(TracingRequest) returns (TracingResponse);

  // See worker.proto for details.
  rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {}

  // See worker.proto for details.
  rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);

  // See worker.proto for details.
  rpc CompleteGroup(CompleteGroupRequest) returns (CompleteGroupResponse);

  // See worker.proto for details.
  rpc CompleteInstance(CompleteInstanceRequest)
      returns (CompleteInstanceResponse);
}
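
To illustrate how a caller crosses this service boundary, here is a hedged sketch of the master-side pattern. It only relies on the GetOrCreateWorker/ReleaseWorker calls and the synchronous GetStatus wrapper shown in this article; the cache pointer and the target job/task name are assumed illustrative values.

// Sketch only: 'cache' is assumed to be a valid WorkerCacheInterface*, and the
// target string is an illustrative TensorFlow job/task name.
const string target = "/job:worker/replica:0/task:1";
WorkerInterface* wi = cache->GetOrCreateWorker(target);
if (wi != nullptr) {
  GetStatusRequest req;
  GetStatusResponse resp;
  // Synchronous wrapper; internally calls GetStatusAsync and waits.
  Status s = wi->GetStatus(&req, &resp);
  if (s.ok()) {
    // resp carries the worker's status, e.g. the attributes of its devices.
  }
  // Workers obtained from the cache must be released through the cache.
  cache->ReleaseWorker(target, wi);
}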

3.3.1 WorkerInterface

Similar to MasterService, access to the WorkerService is done through the WorkerInterface. WorkerInterface is the worker's interface class, used to interact with the TensorFlow worker service. It mainly includes:

  • Asynchronous virtual functions, such as CreateWorkerSessionAsync, which are implemented by derived classes. These virtual functions correspond to the GrpcWorkerMethod entries supported by GrpcWorkerService, as well as to the corresponding protobuf messages.
  • Synchronous functions, such as CreateWorkerSession, which invoke the corresponding asynchronous virtual function through a call of the form CallAndWait(&ME::CreateWorkerSessionAsync, request, response).

We first list its asynchronous interfaces as follows.

// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
 public:
  virtual void GetStatusAsync(CallOptions* opts,
                              const GetStatusRequest* request,
                              GetStatusResponse* response, bool fail_fast,
                              StatusCallback done) = 0;

  virtual void CreateWorkerSessionAsync(
      const CreateWorkerSessionRequest* request,
      CreateWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void DeleteWorkerSessionAsync(
      CallOptions* opts, const DeleteWorkerSessionRequest* request,
      DeleteWorkerSessionResponse* response, StatusCallback done) = 0;

  virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
                                  RegisterGraphResponse* response,
                                  StatusCallback done) = 0;

  virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                    DeregisterGraphResponse* response,
                                    StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                             MutableRunGraphResponseWrapper* response,
                             StatusCallback done) = 0;

  virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
                             RunGraphResponse* response, StatusCallback done) {
    RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
    MutableRunGraphResponseWrapper* wrapped_response =
        new NonOwnedProtoRunGraphResponse(response);
    RunGraphAsync(opts, wrapped_request, wrapped_response,
                  [wrapped_request, wrapped_response,
                   done = std::move(done)](const Status& s) {
                    done(s);
                    delete wrapped_request;
                    delete wrapped_response;
                  });
  }

  virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
                                 CleanupGraphResponse* response,
                                 StatusCallback done) = 0;

  virtual void CleanupAllAsync(const CleanupAllRequest* request,
                               CleanupAllResponse* response,
                               StatusCallback done) = 0;

  virtual void RecvTensorAsync(CallOptions* opts,
                               const RecvTensorRequest* request,
                               TensorResponse* response,
                               StatusCallback done) = 0;

  virtual void LoggingAsync(const LoggingRequest* request,
                            LoggingResponse* response, StatusCallback done) = 0;

  virtual void TracingAsync(const TracingRequest* request,
                            TracingResponse* response, StatusCallback done) = 0;

  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                            RecvBufResponse* response, StatusCallback done) = 0;

  virtual void CompleteGroupAsync(CallOptions* opts,
                                  const CompleteGroupRequest* request,
                                  CompleteGroupResponse* response,
                                  StatusCallback done) = 0;

  virtual void CompleteInstanceAsync(CallOptions* ops,
                                     const CompleteInstanceRequest* request,
                                     CompleteInstanceResponse* response,
                                     StatusCallback done) = 0;

  virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                    GetStepSequenceResponse* response,
                                    StatusCallback done) = 0;
};

WorkerInterface also provides a synchronous interface so that the Master or a Worker can call the remote WorkerService's methods as if they were local functions. The synchronous interfaces are implemented on top of the asynchronous ones and encapsulate the asynchrony using the CallAndWait adapter. In addition, some restrictions prevent external code from deleting a WorkerInterface instance incorrectly: its destructor is protected, WorkerCacheInterface is declared as a friend, and WorkerCacheInterface::ReleaseWorker is responsible for deleting WorkerInterface instances. Below are the synchronous interface, some basic functions, and the member variables.

// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
 public:

  virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
    return new MutableProtoRunGraphRequest;
  }

  virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
    return new OwnedProtoRunGraphResponse;
  }

  Status GetStatus(const GetStatusRequest* request,
                   GetStatusResponse* response) {
    Status ret;
    Notification n;
    GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                   [&ret, &n](const Status& s) {
                     ret = s;
                     n.Notify();
                   });
    n.WaitForNotification();
    return ret;
  }

  Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
                             CreateWorkerSessionResponse* response) {
    return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
  }

  Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
                             DeleteWorkerSessionResponse* response) {
    return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
                                  response);
  }

  Status RegisterGraph(const RegisterGraphRequest* request,
                       RegisterGraphResponse* response) {
    return CallAndWait(&ME::RegisterGraphAsync, request, response);
  }

  Status DeregisterGraph(const DeregisterGraphRequest* request,
                         DeregisterGraphResponse* response) {
    return CallAndWait(&ME::DeregisterGraphAsync, request, response);
  }

  Status CleanupGraph(const CleanupGraphRequest* request,
                      CleanupGraphResponse* response) {
    return CallAndWait(&ME::CleanupGraphAsync, request, response);
  }

  Status CleanupAll(const CleanupAllRequest* request,
                    CleanupAllResponse* response) {
    return CallAndWait(&ME::CleanupAllAsync, request, response);
  }

  Status Logging(const LoggingRequest* request, LoggingResponse* response) {
    return CallAndWait(&ME::LoggingAsync, request, response);
  }

  Status Tracing(const TracingRequest* request, TracingResponse* response) {
    return CallAndWait(&ME::TracingAsync, request, response);
  }

  Status GetStepSequence(const GetStepSequenceRequest* request,
                         GetStepSequenceResponse* response) {
    return CallAndWait(&ME::GetStepSequenceAsync, request, response);
  }

 protected:
  // Instances of WorkerInterface must be deleted by a call to
  // WorkerCacheInterface::ReleaseWorker().
  virtual ~WorkerInterface() {}
  friend class WorkerCacheInterface;

  // NOTE: This should only be called by implementations of this
  // interface whose CreateRunGraphResponse() method returns a
  // proto-based wrappers for the RunGraphResponse message.
  RunGraphResponse* get_proto_from_wrapper(
      MutableRunGraphResponseWrapper* wrapper) {
    return wrapper->get_proto();
  }

 private:
  typedef WorkerInterface ME;

  template <typename Method, typename Req, typename Resp>
  Status CallAndWait(Method func, const Req* req, Resp* resp) {
    Status ret;
    Notification n;
    (this->*func)(req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }

  template <typename Method, typename Req, typename Resp>
  Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
    CallOptions call_opts;
    Status ret;
    Notification n;
    (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
      ret = s;
      n.Notify();
    });
    n.WaitForNotification();
    return ret;
  }
};

3.3.2 Concept sorting

There are many concepts involved in the WorkerService interface that we need to comb through.

As mentioned earlier, the Client and the Master collaborate through the session_handle/MasterSession pair, while the Master and a Worker cooperate through MasterSession and WorkerSession. A MasterSession manages its multiple subordinate WorkerSessions in a unified manner. Several concepts need to be clarified:

  • session_handle: Its purpose is to let a MasterSession centrally manage the multiple WorkerSessions under it. It corresponds to the MasterSession and is generated when the MasterSession is created. It is returned to the Client via CreateSessionResponse and sent to the Worker via CreateWorkerSessionRequest, so the whole chain from Client to Master to Worker is uniquely identified by session_handle.
  • graph_handle: When a subgraph is registered, a graph_handle is generated by GraphMgr::Register and returned to the Master via RegisterGraphResponse. The subgraph is identified by this graph_handle. Within the cluster, the (session_handle, graph_handle) tuple uniquely identifies a particular subgraph.
  • step_id: Since the Master lets multiple workers perform computations concurrently, it broadcasts RunGraph requests to all of them. To distinguish different steps, the Master generates a globally unique step_id for each RunStep, and the step_id is carried to the Worker via the RunGraphRequest message. A small illustrative sketch of these identifiers follows.
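
As a rough illustration, the three identifiers can be thought of as nested scopes. The struct below is hypothetical and does not exist in TensorFlow; it only shows how the identifiers relate.

#include <cstdint>
#include <string>

// Hypothetical illustration only; not a TensorFlow structure.
struct StepIdentity {
  std::string session_handle;  // identifies the MasterSession and its WorkerSessions
  std::string graph_handle;    // identifies one registered subgraph on one worker
  int64_t step_id;             // identifies one concurrent run of that subgraph
};

// (session_handle, graph_handle) uniquely identifies a registered subgraph in the
// cluster; adding step_id identifies one particular execution of that subgraph.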

Let’s tease out graph_handle. GraphMgr::Register generates graph_handle.

Status GraphMgr::Register(
    const string& handle, const GraphDef& gdef, WorkerSession* session,
    const GraphOptions& graph_options, const DebugOptions& debug_options,
    const ConfigProto& config_proto, int64_t collective_graph_key,
    DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
  Item* item = new Item;
  Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                      config_proto, collective_graph_key, cluster_flr, item);
  // Inserts one item into table_.
  {
    mutex_lock l(mu_);
    *graph_handle =
        strings::Printf("%016llx".static_cast<long long>(++next_id_));
    item->handle = *graph_handle;
    CHECK(table_.insert({*graph_handle, item}).second);
  }
  return Status::OK();
}

RegisterGraphResponse will return graph_handle to the Master.

message RegisterGraphResponse {
  // If the registration succeeds, returns an opaque graph_handle to
  // the master. The master calls RunGraph with graph_handle to
  // compute different steps.
  string graph_handle = 1;
}


The partitioned subgraph structure (Part) also holds the graph_handle.

// Graph partitioned into per-location subgraphs.
struct Part {
  // Worker name.
  string name;

  // Maps feed names to rendezvous keys. Empty most of the time.
  std::unordered_map<string, string> feed_key;

  // Maps rendezvous keys to fetch names. Empty most of the time.
  std::unordered_map<string, string> key_fetch;

  // The interface to the worker. Owned.
  WorkerInterface* worker = nullptr;

  // After registration with the worker, graph_handle identifies
  // this partition on the worker.
  string graph_handle;

  Part() : feed_key(3), key_fetch(3) {}
};

The graph_handle is set on the corresponding partition when registration returns.

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.emplace_back();
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->GetOrCreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
      part.worker = nullptr;
    }
    return s;
  }
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.set_create_worker_session_called(!should_deregister_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    StripDefaultAttributes(*OpRegistry::Global(),
                           c->req.mutable_graph_def()->mutable_node());
    *c->req.mutable_config_proto() = session_opts_.config;
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() =
        callable_opts_.run_options().debug_options();
    c->req.set_collective_graph_key(collective_graph_key_);

    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}


The graph_handle is then used to uniquely identify a subgraph, for example when deregistering it.

// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
  struct Call {
    DeregisterGraphRequest req;
    DeregisterGraphResponse resp;
  };
  for (Part& part : partitions_) {
    // The graph handle may be empty if we failed during partition registration.
    if (!part.graph_handle.empty()) {
      Call* c = new Call;
      c->req.set_session_handle(session_handle_);
      c->req.set_create_worker_session_called(!should_deregister_);
      c->req.set_graph_handle(part.graph_handle);
      // NOTE(mrry): We must capture worker_cache_ since this
      // could be deleted before the callback is called.
      WorkerCacheInterface* worker_cache = worker_cache_;
      const string name = part.name;
      WorkerInterface* w = part.worker;
      CHECK_NOTNULL(w);
      auto cb = [worker_cache, c, name, w](const Status& s) {
        delete c;
        worker_cache->ReleaseWorker(name, w);
      };
      w->DeregisterGraphAsync(&c->req, &c->resp, cb);
    }
  }
}

3.3.4 Derived classes of WorkerInterface

As shown in the figure below, there are two implementations of WorkerInterface.

  • GrpcWorker: Worker role in local mode. If the Master and Worker are both local, they can be called directly without RPC network transmission.
  • GrpcRemoteWorker: In distributed mode, the Worker is located at the remote end, and the local user needs to use GrpcRemoteWorker to access the remote Worker.
    • GrpcRemoteWorker is a gRPC client, which accesses GrpcWorkerService on the remote Worker through the stub.
    • GrpcWorkerService implements all interfaces defined by WorkerService, but the actual work is forwarded to the local GrpcWorker.

Specific examples are as follows:

Figure 1 WorkerInterface derived class

3.3.5 Usage

During Server initialization, the Worker Service is set up with the following code.

  // Create GrpcWorker and GrpcWorkerService
  worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                  : NewGrpcWorker(&worker_env_, config);
  worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
                                         opts.worker_service_options);



NewGrpcWorkerService returns a GrpcWorkerService instance.

// Returns an implementation of WorkerService rpc service.
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  return std::unique_ptr<AsyncServiceInterface>(
      new GrpcWorkerService(worker, builder, options));
}



GrpcServer uses the worker_thread_ thread to execute the HandleRPCsLoop method of GrpcWorkerService.

worker_thread_.reset(
    env_->StartThread(ThreadOptions(), "TF_worker_service",
                      [this] { worker_service_->HandleRPCsLoop(); }));

3.3.6 Definition

GrpcWorkerService is defined as follows. Since it needs to act as a daemon handling incoming gRPC requests, service threads are created in the constructor; HandleRPCsLoop then starts these threads and joins them.

class GrpcWorkerService : public AsyncServiceInterface {
 public:
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);

    for (int i = 0; i < options.num_serving_threads; i++) {
      threads_.emplace_back(
          new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                      cache_.get(), &worker_service_));
    }
  }

  // This method blocks forever handling requests from the completion queue.
  void HandleRPCsLoop() override {
    for (auto& worker_thread : threads_) {
      worker_thread->Start();
    }
    for (auto& worker_thread : threads_) {
      worker_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;

  std::unique_ptr<GrpcResponseCache> cache_;
  mutex service_shutdown_mu_;
  bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};


3.3.7 Threads

The actual looping and responding to requests is done in the threads; cq_ is the gRPC completion queue.

// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests. Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
 public:
  explicit GrpcWorkerServiceThread(
      GrpcWorker* worker, ::grpc::ServerBuilder* builder,
      std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
      grpc::WorkerService::AsyncService* worker_service)
      : worker_(worker),
        queue_depth_(queue_depth),
        cache_(cache),
        worker_service_(worker_service),
        is_shutdown_(false) {
    cq_ = builder->AddCompletionQueue();
  }

  void Start() {
    thread_.reset(
        worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                         [this]() { HandleRPCsLoop(); }));
  }
};

The main loop

GrpcWorkerServiceThread::HandleRPCsLoop is the thread's main loop, similar to the one in the master service. It first enqueues a number of pending gRPC calls corresponding to the GrpcWorkerMethod entries. The handler code for each method is covered later.

// Add one or more completion queue entries for each worker method, then
// begin servicing requests from the completion queue.
void GrpcWorkerServiceThread::HandleRPCsLoop() {
  // TODO(ncteisen): This may require performance engineering. We can
  // change the number of threads, the number of handlers per thread,
  // or even decide to specialize certain threads to certain methods.
  SETUP_FOR_REQUEST(GetStatus, 1, false);
  SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
  SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
  SETUP_FOR_REQUEST(CleanupAll, 1, false);
  SETUP_FOR_REQUEST(RegisterGraph, 1, false);
  SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
  SETUP_FOR_REQUEST(Logging, 1, false);
  SETUP_FOR_REQUEST(Tracing, 1, false);
  SETUP_FOR_REQUEST(CompleteGroup, 10, true);
  SETUP_FOR_REQUEST(CompleteInstance, 10, true);
  SETUP_FOR_REQUEST(GetStepSequence, 10, true);
  SETUP_FOR_REQUEST(RecvBuf, 500, true);
  SETUP_FOR_REQUEST(RunGraph, 100, true);
  SETUP_FOR_REQUEST(CleanupGraph, 100, false);
  SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

  // TODO(ncteisen): Determine a better policy for enqueuing the
  // appropriate number of each request type.
  for (int i = 0;
       i < gtl::FindWithDefault(
               queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
               1000);
       ++i) {
    EnqueueRecvTensorRequestRaw();
  }

  void* tag;
  bool ok;

  while (cq_->Next(&tag, &ok)) {
    UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
        static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
    CHECK(callback_tag);
    callback_tag->OnCompleted(this, ok);
  }
}
gRPC requests

Requests are handled similarly to the master. Each request is dispatched to a business handler, namely GrpcWorkerServiceThread::method##Handler, generated by the macro definitions below.

#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
           method##Request, method##Response>::                              \
          EnqueueRequestForMethod(                                           \
              worker_service_, cq_.get(),                                    \
              static_cast<int>(GrpcWorkerMethod::k##method),                 \
              &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
    }                                                                        \
  } while (0)

#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
  for (int i = 0;                                                              \
       i < gtl::FindWithDefault(queue_depth_,                                  \
                                static_cast<int>(GrpcWorkerMethod::k##method), \
                                default_depth);                                \
       ++i) {                                                                  \
    ENQUEUE_REQUEST(method, supports_cancel);                                  \
  }


Each RPC method needs to be registered as an asynchronous service, which is done using gRPC's built-in AddMethod and MarkMethodAsync interfaces.

WorkerService::AsyncService::AsyncService() {
  for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
    AddMethod(new ::grpc::internal::RpcServiceMethod(
        GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
        ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
    ::grpc::Service::MarkMethodAsync(i);
  }
}
Handler & thread pool

Whether a handler's closure is scheduled via the environment's SchedClosure or on the worker's thread pool via compute_pool->Schedule is controlled by the may_block_on_compute_pool flag; both come from the WorkerEnv that the Worker integrates.

  // Handle all non-cancellable simple methods with a standard wrapper.
  // The boolean may_block_on_compute_pool indicates whether or not the
  // operation may block on activities (such as op execution) that run on the
  // compute pool.
#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

#undef HANDLE_CALL
Messages & Methods

GrpcWorkerMethod defines the specific methods of worker.

// Names of worker methods.
enum class GrpcWorkerMethod {
  kGetStatus,
  kCreateWorkerSession,
  kDeleteWorkerSession,
  kRegisterGraph,
  kDeregisterGraph,
  kRunGraph,
  kCleanupGraph,
  kCleanupAll,
  kRecvTensor,
  kRecvBuf,
  kLogging,
  kTracing,
  kCompleteGroup,
  kCompleteInstance,
  kGetStepSequence,
  kMarkRecvFinished,
};


GrpcWorkerMethodName maps these enum values to the corresponding gRPC method names.

const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
  switch (id) {
    case GrpcWorkerMethod::kGetStatus:
      return "/tensorflow.WorkerService/GetStatus";
    case GrpcWorkerMethod::kCreateWorkerSession:
      return "/tensorflow.WorkerService/CreateWorkerSession";
    case GrpcWorkerMethod::kDeleteWorkerSession:
      return "/tensorflow.WorkerService/DeleteWorkerSession";
    case GrpcWorkerMethod::kRegisterGraph:
      return "/tensorflow.WorkerService/RegisterGraph";
    case GrpcWorkerMethod::kDeregisterGraph:
      return "/tensorflow.WorkerService/DeregisterGraph";
    case GrpcWorkerMethod::kRunGraph:
      return "/tensorflow.WorkerService/RunGraph";
    case GrpcWorkerMethod::kCleanupGraph:
      return "/tensorflow.WorkerService/CleanupGraph";
    case GrpcWorkerMethod::kCleanupAll:
      return "/tensorflow.WorkerService/CleanupAll";
    case GrpcWorkerMethod::kRecvTensor:
      return "/tensorflow.WorkerService/RecvTensor";
    case GrpcWorkerMethod::kRecvBuf:
      return "/tensorflow.WorkerService/RecvBuf";
    case GrpcWorkerMethod::kLogging:
      return "/tensorflow.WorkerService/Logging";
    case GrpcWorkerMethod::kTracing:
      return "/tensorflow.WorkerService/Tracing";
    case GrpcWorkerMethod::kCompleteGroup:
      return "/tensorflow.WorkerService/CompleteGroup";
    case GrpcWorkerMethod::kCompleteInstance:
      return "/tensorflow.WorkerService/CompleteInstance";
    case GrpcWorkerMethod::kGetStepSequence:
      return "/tensorflow.WorkerService/GetStepSequence";
    case GrpcWorkerMethod::kMarkRecvFinished:
      return "/tensorflow.WorkerService/MarkRecvFinished";
  }
  // Shouldn't be reached.
  return "invalid id";
}


GrpcWorkerMethodName is called in the AsyncService constructor to complete the gRPC registration.

WorkerService::AsyncService::AsyncService() {
  for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
    AddMethod(new ::grpc::internal::RpcServiceMethod(
        GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
        ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
    ::grpc::Service::MarkMethodAsync(i);
  }
}
The business process

Specific business processing is completed by calling Worker.

void GetStepSequenceHandler(
    WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
  Schedule([this, call]() {
    worker_->GetStepSequenceAsync(
        &call->request, &call->response, [call](const Status& s) {
          call->SendResponse(ToGrpcStatus(s));
        });
  });
  ENQUEUE_REQUEST(GetStepSequence, true);
}


For now, from a threading perspective, the logic is as follows, assuming there are three threads. The server's worker_thread_ runs GrpcWorkerService::HandleRPCsLoop(), whose role is to start two GrpcWorkerServiceThreads; each GrpcWorkerServiceThread then responds to gRPC requests and does the business processing in its own GrpcWorkerServiceThread::HandleRPCsLoop. Note that both GrpcWorkerService and GrpcWorkerServiceThread have a HandleRPCsLoop method.

Figure 2 Thread perspective

3.3.8 Service Logic

CreateWorkerSession

The CreateWorkerSessionRequest message carries the session_handle of the corresponding MasterSession; when the Worker receives the message, it generates a WorkerSession. In a cluster, when a MasterSession establishes a WorkerSession, it passes its own session_handle, so the WorkerSession can use session_handle to know which MasterSession it belongs to, and the MasterSession instance can centrally manage all the WorkerSessions that belong to it.

GrpcWorker manages WorkerSessions through SessionMgr. A WorkerSession can be looked up by master task name or by session_handle.

class SessionMgr {

  WorkerEnv* const worker_env_;  // Not owned.
  std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
  std::shared_ptr<WorkerSession> legacy_session_;
  const WorkerCacheFactory worker_cache_factory_;

  // A map from session identifier to internal session structure.
  std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);

  // Incarnation and WorkerSession handle associated with a master task.
  struct MasterAssociatedSession {
    const int64_t master_incarnation;
    const string session_handle;
  };
  // A map from master task name to its associated worker sessions.
  std::unordered_multimap<string, MasterAssociatedSession>
      master_to_associated_sessions_ TF_GUARDED_BY(mu_);
};


The specific messages are as follows (a master-side usage sketch follows them). Note that CreateWorkerSessionResponse does not return anything:

message CreateWorkerSessionRequest {
  // Sessions are identified by a given handle.
  string session_handle = 1;

  // Defines the configuration of a TensorFlow worker.
  ServerDef server_def = 2;

  // If true, any resources such as Variables used in the session will not be
  // shared with other sessions.
  bool isolate_session_state = 3;

  // The device attributes of all the devices in the cluster.
  repeated DeviceAttributes cluster_device_attributes = 4;

  // The master task name from which the request is sent.
  string master_task = 5;

  // The incarnation ID of the master task local CPU device.
  // If the target worker already has a WorkerSession created previously with
  // the same master task name but a different incarnation, it usually indicates
  // that the previous master failed before deleting the WorkerSession on the
  // worker. To prevent memory leaks, the worker should garbage collect the old
  // WorkerSessions.
  int64 master_incarnation = 6;
}

message CreateWorkerSessionResponse {}

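As a hedged sketch of what the master-side call might look like (not actual TensorFlow code), the request mainly needs the session_handle and the master task information from the proto above. The WorkerInterface* is assumed to come from the worker cache, and the handle, task name, and incarnation values are illustrative.

// Sketch only: 'wi' is assumed to be a WorkerInterface* obtained from the worker cache.
Status SetUpWorkerSession(WorkerInterface* wi, const string& session_handle,
                          const ServerDef& server_def, int64_t master_incarnation) {
  CreateWorkerSessionRequest req;
  req.set_session_handle(session_handle);  // generated when the MasterSession was created
  req.set_master_task("/job:master/replica:0/task:0");  // illustrative master task name
  req.set_master_incarnation(master_incarnation);
  req.set_isolate_session_state(true);
  *req.mutable_server_def() = server_def;  // cluster configuration for the worker

  CreateWorkerSessionResponse resp;  // empty message; nothing comes back
  // Synchronous wrapper; internally calls CreateWorkerSessionAsync and waits.
  return wi->CreateWorkerSession(&req, &resp);
}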

Figure 3 CreateWorkerSession

As mentioned earlier, the GrpcWorker handlers for these messages are generated using macros.

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

RegisterGraph

The RegisterGraphRequest message sends the session_handle corresponding to the MasterSession and the subgraph graph_def. When the Worker receives the message and completes subgraph registration and initialization, it returns the subgraph's graph_handle to the Master.

For each session, after the master has placed each node on a device, it splits the entire graph into many subgraphs. All nodes in a subgraph are on the same worker, but may be on many devices owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). Before running any step, the master registers the subgraphs with the workers. A successful registration returns a graph handle that can be used in later RunGraph requests.

////////////////////////////////////////////////////////////////////////////////
//
// RegisterGraph method request/response messages
//
// For each session, after the master placed every node on a device,
// it partitions the whole graph into many subgraphs. All the nodes in
// a subgraph were in the same worker, but potentially on many devices
// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ... , gpu7). The
// master registers subgraphs for a worker before running any steps. A
// successful registration returns a graph handle to be used in latter
// RunGraph requests.
//
////////////////////////////////////////////////////////////////////////////////

message RegisterGraphRequest {
  // Subgraphs are scoped within one session.
  string session_handle = 1;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 6;

  // "graph_def" has the subgraph of nodes for this worker, with each node
  // having its device_name filled in.
  GraphDef graph_def = 2;

  // True iff the graph (before partitioning) contains control flow nodes.
  //
  // As of 01/11/2015, this is no longer set by clients.
  bool has_control_flow = 3 [deprecated = true];

  // Configuration options for the session in which this graph was created.
  GraphOptions graph_options = 4;

  // Field(s) used by TensorFlow Debugger (tfdbg).
  DebugOptions debug_options = 5;

  // If graph_def contains any collective ops this must be a positive
  // integer used to coordinate execution with other graphs. All
  // graphs in a distributed execution with the same
  // collective_graph_key will coordinate to use the same step_id
  // concurrently so that BufRendezvous entries will make the correct
  // values accessible.
  int64 collective_graph_key = 7;

  // ConfigProto from the session in which this graph was created.
  // Contains additional parameters beyond graph_options, including
  // the name of the requested executor.
  ConfigProto config_proto = 8;
}

message RegisterGraphResponse {
  // If the registration succeeds, returns an opaque graph_handle to
  // the master. The master calls RunGraph with graph_handle to
  // compute different steps.
  string graph_handle = 1;
}


Figure 4 RegisterGraph

DeregisterGraph

When a graph is no longer needed (for example, if the entire graph is rescheduled and its nodes are re-placed), the Master uses the graph_handle to deregister the graph. In case of a Master restart, the Worker automatically deregisters the corresponding graph_handle according to a TTL-based policy.

////////////////////////////////////////////////////////////////////////////////
//
// DeregisterGraph method request/response messages
//
// The master deregisters the given graph_handle when the graph is no
// longer needed (e.g., the overall graph is re-scheduled and nodes
// are re-placed).
//
// The worker deregisters a graph_handle automatically according to on
// a TTL-base policy in case of master restarts.
//
////////////////////////////////////////////////////////////////////////////////

message DeregisterGraphRequest {
  // The session_handle used when registering the graph. If session_handle is
  // empty, a single global namespace is used.
  string session_handle = 2;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 3;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  string graph_handle = 1;
}

message DeregisterGraphResponse {
  // TODO(mrry): Optionally add summary stats for the graph.
}


Figure 5 DeregisterGraph

RunGraph

The Master uses RunGraphRequest to execute all subgraphs registered under a graph_handle.

The Master generates a globally unique step_id to distinguish the different runs of the graph computation. Subgraphs can communicate with each other (for example, via send/recv operations), using step_id to distinguish tensors produced by different runs.

The send field of the RunGraphRequest message carries the subgraph's input tensors, and recv_key indicates the subgraph's output tensors. RunGraphResponse returns the list of tensors for recv_key. A hedged sketch of building such a request follows the proto below.

Figure 6 RunGraph

////////////////////////////////////////////////////////////////////////////////
//
// RunGraph request / response messages
//
// The worker executes all subgraphs registered under graph_handle.
// RunGraph returns after the execution finishes or an error is
// encountered.
// A sequence of RunGraphRequests with is_partial may be sent to RunGraph for
// partial graph execution.
//
////////////////////////////////////////////////////////////////////////////////

// Options specific to the execution of a single step.
message ExecutorOpts {
  bool record_costs = 1;
  bool record_timeline = 3;
  bool record_partition_graphs = 4;
  bool report_tensor_allocations_upon_oom = 5;
}

message RunGraphRequest {
  // session_handle is the master-generated unique id for this session.
  // If session_handle is non-empty, it must be the same as used when
  // registering the graph. If it is empty, a single global namespace is used to
  // search for the graph_handle.
  string session_handle = 8;

  // Set to true if CreateWorkerSession was called for session_handle.
  bool create_worker_session_called = 10;

  // REQUIRED: graph_handle must be returned by a RegisterGraph call
  // to the same WorkerService.
  string graph_handle = 1;

  // A unique ID to distinguish different runs of the same graph.
  //
  // The master generates a global unique step_id to distinguish
  // different runs of the graph computation. Subgraphs communicate
  // (e.g., send/recv ops) with each other using step_id to
  // distinguish tensors generated by different runs.
  int64 step_id = 2;

  // Options for this step.
  ExecutorOpts exec_opts = 5;

  // Runs the graph.
  //
  // Sends the tensors in "send" into the graph before the run and
  // fetches the keys into RunGraphResponse.recv after the run.
  repeated NamedTensorProto send = 3;
  repeated string recv_key = 4;

  // True if the RunGraphRequest is a partial run request.
  bool is_partial = 6;
  // True if this is the last partial run request in a sequence of requests.
  bool is_last_partial_run = 7;

  // If true then some errors, e.g., execution errors that have long
  // error messages, may return an OK RunGraphResponse with the actual
  // error saved in the status_code/status_error_message fields of the
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  bool store_errors_in_response_body = 9;

  // Unique identifier for this request. Every RunGraphRequest must have a
  // unique request_id, and retried RunGraphRequests must have the same
  // request_id. If request_id is zero, retry detection is disabled.
  //
  // Retried RunGraphRequests are problematic because they may issue a
  // RecvTensor that will have no corresponding sender and will wait forever.
  // Workers use request_ids to reject retried RunGraph requests instead of
  // waiting forever.
  int64 request_id = 11;

  // Next: 12
}

message RunGraphResponse {
  // A list of tensors corresponding to those requested by
  // RunGraphRequest.recv_key.
  repeated NamedTensorProto recv = 1;

  // If the request asked for execution stats, the cost graph, or the partition
  // graphs, these are returned here.
  // TODO(suharshs): Package these in a RunMetadata instead.
  StepStats step_stats = 2;
  CostGraphDef cost_graph = 3;
  repeated GraphDef partition_graph = 4;

  // If store_errors_in_response_body is true in the request, then
  // optionally the server may return an OK status for the RPC and
  // fill the true status into the fields below, to allow for messages
  // that are too long to fit in metadata.
  error.Code status_code = 5;
  string status_error_message = 6;
}

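As referenced above, here is a hedged sketch of how a master-side caller might populate a RunGraphRequest and read the response. It is not TensorFlow code: it uses only fields from the proto above and the RunGraphAsync signature shown earlier; the handles, step id, and feed/fetch names are illustrative, and 'wi' is an assumed WorkerInterface* for the target worker.

// Sketch only: handles, step id, and key names are illustrative values.
void RunOnePartition(WorkerInterface* wi, const string& session_handle,
                     const string& graph_handle, int64_t step_id) {
  RunGraphRequest req;
  req.set_session_handle(session_handle);  // same handle used at RegisterGraph time
  req.set_graph_handle(graph_handle);      // returned earlier in RegisterGraphResponse
  req.set_step_id(step_id);                // globally unique id generated by the master

  // Feed one input tensor into the subgraph.
  NamedTensorProto* feed = req.add_send();
  feed->set_name("feed_key_0");            // illustrative feed name
  // feed->mutable_tensor() would be filled with the serialized input tensor.

  // Ask for one output tensor back.
  req.add_recv_key("fetch_key_0");         // illustrative fetch name

  RunGraphResponse resp;
  CallOptions call_opts;
  Notification n;
  Status ret;
  wi->RunGraphAsync(&call_opts, &req, &resp, [&ret, &n](const Status& s) {
    ret = s;
    n.Notify();
  });
  n.WaitForNotification();  // keep req/resp alive until the callback has run
  if (ret.ok()) {
    // resp.recv() now holds the tensors requested via recv_key.
  }
}
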
RecvTensor

During actual execution, data may need to be exchanged between two workers. In that case the producer simply puts the prepared tensor into the rendezvous, and the consumer initiates a RecvTensorRequest. In the RecvTensorRequest, step_id identifies the step, and rendezvous_key identifies the channel from which to receive the tensor.

A RecvTensor request retrieves one tensor from the channel; multiple tensors can be sent and received over the same channel through multiple RecvTensor requests. Finally, the producer's tensor is returned to the consumer via RecvTensorResponse. A hedged sketch of the consumer side follows the proto below.

Figure 7 RecvTensor

////////////////////////////////////////////////////////////////////////////////
//
// RecvTensor method request/response messages
//
////////////////////////////////////////////////////////////////////////////////

message RecvTensorRequest {
  // The step in which the tensor will be produced.
  //
  // REQUIRED: This must eventually correspond to the step_id passed
  // into a RunGraph call on the same WorkerService.
  int64 step_id = 1;

  // A key identifying the channel to receive tensors from. A RecvTensor request
  // retrieves one tensor from the channel, but multiple tensors can be sent and
  // received over the same channel with multiple RecvTensor requests. See
  // rendezvous.h for details.
  string rendezvous_key = 2;

  // If true, use an out-of-band DMA mechanism to transfer the
  // received tensor.
  bool dma_ok = 3;

  // Optional information on client-side device locality.
  DeviceLocality client_locality = 4;

  // Optional information on server-side device locality.
  DeviceLocality server_locality = 5;

  // Optional information needed by the RPC subsystem.
  google.protobuf.Any transport_options = 6;

  // Unique identifier for this request. Every RecvTensorRequest must have a
  // unique request_id, and retried RecvTensorRequests must have the same
  // request_id. If request_id is zero, retry detection and response cache
  // are disabled.
  //
  // Retried RecvTensorRequests are problematic because a RecvTensor with no
  // corresponding sender will wait forever, and the tensor may have been
  // delivered to a previous retry. Workers use request_ids to reject retried
  // RecvTensor requests instead of waiting forever.
  int64 request_id = 7;
}

message RecvTensorResponse {
  // The tensor as a proto.
  TensorProto tensor = 1;

  // If true, this tensor was the output of a dead node, and the
  // content is invalid.
  bool is_dead = 2;

  // The time at which tensor was available and started to be returned.
  int64 send_start_micros = 3;

  // Optional additional information about how to receive the tensor,
  // e.g. in the event that RecvTensorRequest.dma_ok was true.
  google.protobuf.Any transport_options = 4;

  // Whether the receiver should send a MarkRecvFinishedRequest to the sender
  // to ack the message.
  bool require_ack = 5;
}

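As referenced above, a hedged sketch of the consumer side might look like the following. It is not TensorFlow code: it uses only the RecvTensorRequest fields from the proto above and the RecvTensorAsync signature shown earlier. The step id and rendezvous key values are illustrative (the real key string is produced by the Rendezvous machinery), and 'wi' is an assumed WorkerInterface* for the producer worker.

// Sketch only: step id and rendezvous key are illustrative.
void FetchOneTensor(WorkerInterface* wi, int64_t step_id,
                    const string& rendezvous_key) {
  RecvTensorRequest req;
  req.set_step_id(step_id);                // must match the producing RunGraph step
  req.set_rendezvous_key(rendezvous_key);  // names the channel between src and dst

  TensorResponse resp;  // specialized response type that avoids extra copies
  CallOptions call_opts;
  Notification n;
  Status ret;
  wi->RecvTensorAsync(&call_opts, &req, &resp, [&ret, &n](const Status& s) {
    ret = s;
    n.Notify();
  });
  n.WaitForNotification();  // keep req/resp alive until the callback has run
  if (ret.ok()) {
    // The received tensor is now available in the TensorResponse.
  }
}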

4. Worker

The Worker class mainly provides WorkerEnv and PartialRunMgr. It can be subclassed to provide specialized implementations of specific methods for different transport mechanisms. For example, GrpcWorker specializes the RecvTensorAsync method to use more efficient gRPC data structures for handling large binary data.

class Worker : public WorkerInterface {
 protected:
  WorkerEnv* const env_;  // Not owned.
  RecentRequestIds recent_request_ids_;

 private:
  PartialRunMgr partial_run_mgr_;

  CancellationManager cancellation_manager_;

  TF_DISALLOW_COPY_AND_ASSIGN(Worker);
};


Let's take one of its methods as an example; we'll cover the others when we encounter them.

void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}


5. GrpcWorker

GrpcWorker is the remote Worker that GrpcRemoteWorker corresponds to. It is also the object called by GrpcWorkerService and implements the actual business logic. It is defined as follows; we can see that it overrides several methods.

class GrpcWorker : public Worker {
 public:
  GrpcWorker(WorkerEnv* env, const ConfigProto& config);

  // Specialized version of RecvTensor for gRPC, which avoids a copy.
  virtual void GrpcRecvTensorAsync(CallOptions* opts,
                                   const RecvTensorRequest* request,
                                   ::grpc::ByteBuffer* response,
                                   StatusCallback done);

  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
                    StatusCallback done) override;

  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                    RecvBufResponse* response, StatusCallback done) override;

  void CleanupGraphAsync(const CleanupGraphRequest* request,
                         CleanupGraphResponse* response,
                         StatusCallback done) override;

  WorkerEnv* env();

  void EnableResponseCache();

  void RemoveCacheEntryForId(int64 request_id);

 private:
  std::unique_ptr<GrpcResponseCache> response_cache_;
  const int32 recv_buf_max_chunk_;
};


So far, the static structure of Worker has been introduced. Specific Worker functions will be introduced in the Session section later.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts
