Before detailing the various TensorFlow distribution strategies, we first need to look at the foundation of distribution: the distributed runtime environment. Only with a solid foundation can we clear obstacles out of the way in the later analysis and get twice the result with half the effort. The Session mechanism is the core of the TensorFlow distributed runtime. Let's follow the flow from Client to Worker and walk through the Session mechanism from front to back.

The other articles in this series are:

Heterogeneous distributed learning based on the TensorFlow distributed paper [translation]

Implementation of Control Flow in TensorFlow

TensorFlow Distributed environment (1) — overall architecture

TensorFlow distributed environment (2) — Master static logic

TensorFlow distributed environment (3) — Worker static logic

TensorFlow distributed environment (4) — WorkerCache

1. Overview

1.1 Session Classification

Distributed mode is controlled by the following sessions collaborating with each other:

  • GrpcSession resides on the Client and controls the session life cycle of the Client.
  • A MasterSession resides on the Master. Multiple Clients may access the same Master at the same time, and the Master constructs a MasterSession for each Client. The MasterSession controls the life cycle of the Master-side session.
  • A WorkerSession resides on the Worker. Multiple Masters may connect to the same Worker, and the Worker creates a WorkerSession for each Master. The WorkerSession controls the life cycle of the Worker-side session.

As shown in the figure below, both the Master and the Worker are Servers, and each Server runs a MasterService and a WorkerService. Which role a Server actually plays depends on how the user configures the computation graph and the cluster. Because of this two-layer, one-to-many relationship, the three logically related sessions are bound to the same session_handle in order to distinguish the different data flows and control relationships; each session_handle identifies a complete data flow.

Figure 1 Session relationship
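
To make the per-Server configuration concrete, the following is a minimal sketch (job name, task indices, and addresses are made up for illustration; header paths may differ across TF versions) of how a user brings up one such Server, which then hosts both a MasterService and a WorkerService:

#include "tensorflow/core/distributed_runtime/server_lib.h"

int main() {
  // Describe task 0 of job "worker" in a two-task cluster.
  tensorflow::ServerDef server_def;
  server_def.set_protocol("grpc");
  server_def.set_job_name("worker");
  server_def.set_task_index(0);
  auto* job = server_def.mutable_cluster()->add_job();
  job->set_name("worker");
  (*job->mutable_tasks())[0] = "localhost:2222";  // this task
  (*job->mutable_tasks())[1] = "localhost:2223";  // a peer task

  // The in-process server exposes both MasterService and WorkerService.
  std::unique_ptr<tensorflow::ServerInterface> server;
  TF_CHECK_OK(tensorflow::NewServer(server_def, &server));
  TF_CHECK_OK(server->Start());
  TF_CHECK_OK(server->Join());
  return 0;
}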

1.2 Session Flow

Let’s start with GrpcSession, which has the following basic functions:

  • Create a session
    • Obtain the remote device set.
    • Create MasterSession on Master;
    • Create a WorkerSession on each Worker;
  • Iteration
    • Start execution;
    • Partition the graph;
    • Register subgraphs;
    • Run subgraphs;
  • Close the session
    • Close the MasterSession
    • Close the WorkerSession;
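
To see these functions from the caller's side, here is a minimal usage sketch against the public C++ Session API (the target address, graph, and fetch name are placeholders):

#include "tensorflow/core/public/session.h"

int main() {
  tensorflow::SessionOptions options;
  options.target = "grpc://localhost:2222";  // hypothetical Master address

  // NewSession goes through the factory registry; the "grpc://" prefix
  // selects GrpcSessionFactory, so a GrpcSession is returned.
  std::unique_ptr<tensorflow::Session> session(tensorflow::NewSession(options));

  tensorflow::GraphDef graph_def;            // assume a graph built elsewhere
  TF_CHECK_OK(session->Create(graph_def));   // MasterSession / WorkerSessions

  std::vector<tensorflow::Tensor> outputs;   // one iteration (placeholder fetch)
  TF_CHECK_OK(session->Run({}, {"out:0"}, {}, &outputs));

  TF_CHECK_OK(session->Close());             // tear the sessions down
  return 0;
}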

1.2.1 MasterSession Life Cycle

In distributed mode, the Master runtime is controlled by MasterSession and its lifecycle is shown below.

Figure 2 MasterSession lifecycle

1.2.2 WorkerSession Life Cycle

In distributed mode, Worker runs under the control of WorkerSession, and its life cycle is shown in the figure below.

Figure 3 WorkerSession lifecycle

2. GrpcSession

GrpcSession is a simple wrapper around tensorflow::grpc::MasterService. It uses a remote device set as its computing resource and gRPC as the remote invocation mechanism, letting the caller compute a TensorFlow graph on remote devices.

2.1 Definition

As before, we show only a few member variables and important functions. The key point is that GrpcSession uses master_ to invoke tensorflow::grpc::MasterService.

class GrpcSession : public Session {
  // There are several ways to create it
  Status Create(const GraphDef& graph) override;
  Status Create(const RunOptions& run_options, const GraphDef& graph) override;
  Status Create(GraphDef&& graph) override;
  Status Create(const RunOptions& run_options, GraphDef&& graph) override;  
  
 private:
  const SessionOptions options_;
  std::unique_ptr<MasterInterface> master_;
  mutex mu_;

  // handle_ returned by the master to identify this session.
  string handle_ TF_GUARDED_BY(mu_);

  // The current version of the graph.
  int64_t current_graph_version_ TF_GUARDED_BY(mu_);

  bool is_local_ = false;
};

2.2 Registration & Factory Classes

The use of GrpcSession is done through factory classes such as:

Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    *out_session = nullptr;
    return s;
  }
  // Starts exporting metrics through a platform-specific monitoring API (if
  // provided). For builds using "tensorflow/core/platform/default", this is
  // currently a no-op.
  session_created->GetCell()->Set(true);
  s = factory->NewSession(options, out_session);
  if (!s.ok()) {
    *out_session = nullptr;
  }
  return s;
}

A GrpcSession is created polymorphically by a GrpcSessionFactory: if the target protocol is "grpc://", a GrpcSession is generated. The GrpcSessionFactory implementation is registered with the system as follows.

const char* const kSchemePrefix = "grpc://";
const size_t kSchemePrefixLength = strlen(kSchemePrefix);

class GrpcSessionFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) override {
    return absl::StartsWith(options.target, kSchemePrefix);
  }

  Status NewSession(const SessionOptions& options,
                    Session** out_session) override {
    std::unique_ptr<GrpcSession> session;
    TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session));
    *out_session = session.release();
    return Status::OK();
  }

  // Invokes the session specific static method to reset containers.
  Status Reset(const SessionOptions& options,
               const std::vector<string>& containers) override {
    return GrpcSession::Reset(options, containers);
  }
};

class GrpcSessionRegistrar {
 public:
  GrpcSessionRegistrar() {
    SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
  }
};

static GrpcSessionRegistrar registrar;

2.3 Creating a GrpcSession

The GrpcSession::Create method does the job. The Client invokes the Master Service through the GrpcSession. How does the Client interact with the Master Service? Use MasterInterface.

So, the most important thing here is how to build the MasterInterface instance. As mentioned earlier, there are two implementations of MasterInterface, both of which are used to communicate with the Master Service in different application scenarios.

  • LocalMaster is used for direct communication between processes. In this case, the Client and Master are in the same process.
  • GrpcRemoteMaster uses gRPC to communicate with the Master service; the Client and Master are deployed in two different processes. GrpcRemoteMaster implements the gRPC client and accesses the MasterService on the remote Master through a stub.
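
For reference, here is an abridged sketch of the MasterInterface abstraction (the method list is trimmed; see tensorflow/core/distributed_runtime/master_interface.h for the full interface):

// Abridged: the Client-side abstraction over the Master service. Both
// LocalMaster (same process) and GrpcRemoteMaster (gRPC stub) implement it.
class MasterInterface {
 public:
  virtual ~MasterInterface() {}
  virtual Status CreateSession(CallOptions* call_options,
                               const CreateSessionRequest* request,
                               CreateSessionResponse* response) = 0;
  virtual Status ExtendSession(CallOptions* call_options,
                               const ExtendSessionRequest* request,
                               ExtendSessionResponse* response) = 0;
  virtual Status RunStep(CallOptions* call_options,
                         RunStepRequestWrapper* request,
                         MutableRunStepResponseWrapper* response) = 0;
  virtual Status CloseSession(CallOptions* call_options,
                              const CloseSessionRequest* request,
                              CloseSessionResponse* response) = 0;
  // ... PartialRunSetup, ListDevices, Reset, MakeCallable, etc.
};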

The two rectangles encapsulating the Master represent the actual Master class, which implements specific Master functions.

Figure 1. Master logic

As the following code shows, GrpcSession is created based on options.target, whose prefix is "grpc://". If a LocalMaster instance can be found via LocalMaster::Lookup, it is used directly; otherwise NewGrpcMaster is called to generate a GrpcRemoteMaster.

/* static */
Status GrpcSession::Create(const SessionOptions& options,
                           std::unique_ptr<GrpcSession>* out_session) {
  std::unique_ptr<GrpcSession> session(new GrpcSession(options));
  std::unique_ptr<MasterInterface> master;
  // For testing, we enable the client to disable the use of the local
  // master registry, so that the RPC stack is exercised.
  if (!options.config.rpc_options().use_rpc_for_inprocess_master()) {
    master = LocalMaster::Lookup(options.target);
  }
  if (!master) {
    SharedGrpcChannelPtr master_channel;
    TF_RETURN_IF_ERROR(
        NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
                               &options.config.rpc_options(), &master_channel));
    master.reset(NewGrpcMaster(master_channel));
  } else {
    session->is_local_ = true;
  }
  session->SetRemoteMaster(std::move(master));
  *out_session = std::move(session);
  return Status::OK();
}

2.4 Creating a MasterSession

After the GrpcSession is created, the MasterSession is created through GrpcSession::Create(graph_def). GrpcSession::Create(graph_def) builds a CreateSessionRequest message and sends the initial computation graph to the Master via GrpcRemoteMaster. When the Master receives the CreateSessionRequest, it constructs the corresponding MasterSession and returns a CreateSessionResponse to the GrpcSession, which mainly contains:

  • The session_handle of this MasterSession, which identifies the MasterSession instance on the Master side;
  • The version number of the initial computation graph, graph_version, which is used for subsequent ExtendSession operations such as appending new nodes to the original graph.
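
As a hedged sketch (not the actual GrpcSession source) of how these two values are used afterwards, a later ExtendSession call would carry the stored handle and graph version roughly like this:

// Sketch: append new nodes to the registered graph. Field names follow
// master.proto; nodes_to_append is a hypothetical GraphDef delta.
ExtendSessionRequest req;
req.set_session_handle(handle_);                  // identifies the MasterSession
req.set_current_graph_version(current_graph_version_);
*req.mutable_graph_def() = nodes_to_append;
ExtendSessionResponse resp;
Status s = master_->ExtendSession(&call_options, &req, &resp);
if (s.ok()) {
  current_graph_version_ = resp.new_graph_version();  // remember the new version
}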

Figure 2 Creating MasterSession

The code is as follows, starting with two create methods that eventually call CreateImpl.

Status GrpcSession::Create(const RunOptions& run_options,
                           const GraphDef& graph) {
  return Create(run_options, GraphDef(graph));
}

Status GrpcSession::Create(GraphDef&& graph) {
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  return CreateImpl(&call_options, std::move(graph));
}

The CreateImpl method looks like this:

Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) {
  {
    mutex_lock l(mu_);
    if (!handle_.empty()) {
      return errors::InvalidArgument("A session is alive.");
    }
  }
  CreateSessionRequest req;
  *req.mutable_config() = options_.config;
  req.mutable_graph_def()->Swap(&graph);
  req.set_target(options_.target);
  ReEncodeConsts(req.mutable_graph_def());
  CreateSessionResponse resp;
  Status s = master_->CreateSession(call_options, &req, &resp);
  if (s.ok()) {
    SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version());
  }
  return s;
}

2.4.1 GrpcRemoteMaster::CreateSession

GrpcRemoteMaster is the gRPC client implementation that resides on the Client side. Its CreateSession method simply calls the CreateSession interface of the remote MasterService through the gRPC stub, sending a CreateSessionRequest.

Status CreateSession(CallOptions* call_options,
                     const CreateSessionRequest* request,
                     CreateSessionResponse* response) override {
  return CallWithRetry(call_options, request, response,
                       &MasterServiceStub::CreateSession);
}
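
CallWithRetry is a small template helper; the sketch below is a simplified illustration of the retry-on-UNAVAILABLE pattern it wraps around the stub call (this is not the actual GrpcRemoteMaster code; the retry bound and backoff are made up):

// Simplified sketch: invoke the stub method; retry transient UNAVAILABLE
// errors with a crude backoff, otherwise return the status immediately.
template <typename Stub, typename Request, typename Response, typename StubMethod>
Status CallWithRetrySketch(Stub* stub, const Request* request,
                           Response* response, StubMethod method) {
  Status s;
  for (int attempt = 0; attempt < 10; ++attempt) {  // bound is illustrative
    ::grpc::ClientContext ctx;
    ctx.set_fail_fast(false);
    s = FromGrpcStatus((stub->*method)(&ctx, *request, response));
    if (!errors::IsUnavailable(s)) return s;        // success or hard error
    Env::Default()->SleepForMicroseconds(10000 << attempt);
  }
  return s;
}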

2.4.2 GrpcMasterService::CreateSessionHandler

GrpcMasterService is the gRPC service provided by the Master. When it receives a CreateSessionRequest message, the service invokes GrpcMasterService::CreateSessionHandler to process it. The actual business logic is handled by master_impl_ (an instance of the Master class), i.e. Master::CreateSession is called.

When master_impl_ finishes processing, a CreateSessionResponse is returned to the Client.

// RPC handler for creating a session.
void CreateSessionHandler(
    MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
  CreateSessionRequest* rewritten_req = new CreateSessionRequest;
  rewritten_req->mutable_config()->MergeFrom(default_session_config_);
  rewritten_req->MergeFrom(call->request);
  master_impl_->CreateSession(rewritten_req, &call->response,
                              [call, rewritten_req](const Status& status) {
                                call->SendResponse(ToGrpcStatus(status));
                                delete rewritten_req;
                              });
  ENQUEUE_REQUEST(CreateSession, true);
}

2.4.3 Master::CreateSession

Master::CreateSession obtains a thread from the thread pool and does the following:

  • If cluster_spec is defined, all Workers are found according to the configuration.
  • Obtain the remote devices.
  • Obtain the remote Workers.
  • Build the MasterSession through the factory.
  • Using worker_cache_factory, let the MasterSession establish WorkerSession sessions.
  • Insert the <session_handle, MasterSession> pair into sessions_. The Master can then use session_handle to obtain the corresponding MasterSession.

void Master::CreateSession(const CreateSessionRequest* req,
                           CreateSessionResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
    Status status;
    WorkerCacheFactoryOptions worker_cache_factory_options;
    string grpc_protocol("grpc");
    worker_cache_factory_options.protocol = &grpc_protocol;
    auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
    status = ValidateExternalGraphDefSyntax(req->graph_def());
    if (!status.ok()) return;

    // The following 4 variables are set differently, depending on whether this
    // session uses a client-provided clusterspec or not.
    WorkerCacheInterface* worker_cache = nullptr;
    // Note: worker_cache_ptr will be null except if this session is using a
    // client-supplied ClusterDef (ClusterSpec propagation).
    std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
    std::unique_ptr<DeviceSet> device_set;
    // TODO(saeta): Convert to std::make_unique when available.
    std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
        new std::vector<std::unique_ptr<Device>>());

    if (req->config().has_cluster_def()) {  // If a cluster is defined
      worker_cache_factory_options.cluster_def = &req->config().cluster_def();

      // Set the server_def's job_name and task_index fields.
      string normalized_string;
      string grpc_protocol(kGrpcProtocol);
      if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
          0) {
        normalized_string =
            req->target().substr(grpc_protocol.length(), string::npos);
      } else {
        normalized_string = req->target();
      }
      for (auto&& job : req->config().cluster_def().job()) {
        for (auto&& task : job.tasks()) {
          if (task.second == normalized_string) {
            if (worker_cache_factory_options.job_name != nullptr) {
              return;
            }
            if (env_->local_devices[0]->parsed_name().job == job.name() &&
                env_->local_devices[0]->parsed_name().task == task.first) {
              return;
            }
            worker_cache_factory_options.job_name = &job.name();
            worker_cache_factory_options.task_index = task.first;
          }
        }
      }
      worker_cache_factory_options.rpc_options = &req->config().rpc_options();
      // Create the worker cache from the computed server_def.
      status = env_->worker_cache_factory(worker_cache_factory_options,
                                          &worker_cache);
      if (!status.ok()) return;
      worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
      // Ping all the workers and build the list of devices that the
      // session will use (get the remote devices).
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());
      if (!status.ok()) return;
      device_set.reset(new DeviceSet);
      for (auto&& d : *remote_devices) {
        device_set->AddDevice(d.get());
        DeviceNameUtils::ParsedName name = d->parsed_name();
        if (name.job == *worker_cache_factory_options.job_name &&
            name.task == worker_cache_factory_options.task_index &&
            name.type == "CPU" && name.id == 0) {
          device_set->set_client_device(d.get());
        }
      }
    } else {  // There is no cluster
      worker_cache = env_->worker_cache;
      // Ping all the workers and build the list of devices that the
      // session will use (obtain the remote devices).
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());
      if (!status.ok()) return;
      device_set.reset(new DeviceSet);
      for (auto&& d : *remote_devices) {
        device_set->AddDevice(d.get());
      }
      int num_local_devices = 0;
      for (Device* d : env_->local_devices) {
        device_set->AddDevice(d);
        if (num_local_devices == 0) {
          // Uses the first local device as the client device.
          device_set->set_client_device(d);
        }
        num_local_devices++;
      }
    }

    SessionOptions options;
    options.config = req->config();

    // Get the remote workers.
    std::vector<string> filtered_worker_list;
    DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
                                   worker_cache, &filtered_worker_list);

    // Build the session via the factory.
    MasterSession* session = env_->master_session_factory(
        options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
        std::move(device_set), std::move(filtered_worker_list));

    GraphDef* gdef =
        const_cast<CreateSessionRequest*>(req)->mutable_graph_def();

    // Set up the session and pass the graph to it.
    status = session->Create(std::move(*gdef), worker_cache_factory_options);
    if (!status.ok()) {
      session->Close().IgnoreError();
      session->Unref();
      return;
    }
    resp->set_session_handle(session->handle());
    // Insert into the session map, which takes ownership of the session.
    {
      mutex_lock l(mu_);
      CHECK(sessions_.insert({session->handle(), session}).second);
    }
  });
}

3. MasterSession

A MasterSession resides on the Master. Multiple Clients may access the same Master at the same time, and the Master constructs a MasterSession for each Client. The MasterSession controls the life cycle of the Master-side session.

3.1 Definition

Part of the MasterSession definition is shown below. The excerpt is its nested ReffedClientGraph class, which wraps ClientGraph in a reference-counted object so that the cache mapping Run requests to compiled graphs can be cleared while a compiled graph is still in use.

// MasterSession wraps ClientGraph in a reference counted object.
// This way, MasterSession can clear up the cache mapping Run requests to
// compiled graphs while the compiled graph is still being used.
class MasterSession::ReffedClientGraph : public core::RefCounted {
 public:
  ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
                    std::unique_ptr<ClientGraph> client_graph,
                    const SessionOptions& session_opts,
                    const StatsPublisherFactory& stats_publisher_factory,
                    bool is_partial, WorkerCacheInterface* worker_cache,
                    bool should_deregister)
      : session_handle_(handle),
        bg_opts_(bopts),
        client_graph_before_register_(std::move(client_graph)),
        session_opts_(session_opts),
        is_partial_(is_partial),
        callable_opts_(bopts.callable_options),
        worker_cache_(worker_cache),
        should_deregister_(should_deregister),
        collective_graph_key_(
            client_graph_before_register_->collective_graph_key) {
    VLOG(1) << "Created ReffedClientGraph for node with "
            << client_graph_before_register_->graph.num_node_ids();
    stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);

    // Initialize a name to node map for processing device stats.
    for (Node* n : client_graph_before_register_->graph.nodes()) {
      name_to_node_details_.emplace(
          n->name(),
          NodeDetails(n->type_string(),
                      strings::StrCat(
                          "(", absl::StrJoin(n->requested_inputs(), ","))));
    }
  }

  ~ReffedClientGraph() override {
    if (should_deregister_) {
      DeregisterPartitions();
    } else {
      for (Part& part : partitions_) {
        worker_cache_->ReleaseWorker(part.name, part.worker);
      }
    }
  }

 private:
  const string session_handle_;
  const BuildGraphOptions bg_opts_;

  // NOTE(mrry): This pointer will be null after `RegisterPartitions()` returns.
  std::unique_ptr<ClientGraph> client_graph_before_register_ TF_GUARDED_BY(mu_);
  const SessionOptions session_opts_;
  const bool is_partial_;
  const CallableOptions callable_opts_;
  WorkerCacheInterface* const worker_cache_;  // Not owned.

  struct NodeDetails {
    explicit NodeDetails(string type_string, string detail_text)
        : type_string(std::move(type_string)),
          detail_text(std::move(detail_text)) {}
    const string type_string;
    const string detail_text;
  };
  std::unordered_map<string, NodeDetails> name_to_node_details_;

  const bool should_deregister_;
  const int64_t collective_graph_key_;
  std::atomic<int64_t> execution_count_ = {0};

  // Graph partitioned into per-location subgraphs.
  struct Part {
    // Worker name.
    string name;

    // Maps feed names to rendezvous keys. Empty most of the time.
    std::unordered_map<string, string> feed_key;

    // Maps rendezvous keys to fetch names. Empty most of the time.
    std::unordered_map<string, string> key_fetch;

    // The interface to the worker. Owned.
    WorkerInterface* worker = nullptr;

    // After registration with the worker, graph_handle identifies
    // this partition on the worker.
    string graph_handle;

    Part() : feed_key(3), key_fetch(3) {}
  };

  // partitions_ is immutable after RegisterPartitions() call
  // finishes. RunPartitions() can access partitions_ safely without
  // acquiring locks.
  std::vector<Part> partitions_;

  mutable mutex mu_;

  // Partition initialization and registration only needs to happen
  // once. `!client_graph_before_register_ && !init_done_.HasBeenNotified()`
  // indicates the initialization is ongoing.
  Notification init_done_;

  // init_result_ remembers the initialization error if any.
  Status init_result_ TF_GUARDED_BY(mu_);

  std::unique_ptr<StatsPublisherInterface> stats_publisher_;
};

3.2 Creation

MasterSession::Create(graph_def) works like this:

  • MakeForBaseGraph is called to initialize the computation graph and generate a GraphExecutionState instance;
  • CreateWorkerSessions is called to broadcast to all Workers and have them create the corresponding WorkerSessions, if the cluster is dynamically configured.

Status MasterSession::Create(GraphDef&& graph_def,
                             const WorkerCacheFactoryOptions& options) {
  if (session_opts_.config.use_per_session_threads() ||
      session_opts_.config.session_inter_op_thread_pool_size() > 0) {
    return errors::InvalidArgument(
        "Distributed session does not support session thread pool options.");
  }
  if (session_opts_.config.graph_options().place_pruned_graph()) {
    session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
  }

  GraphExecutionStateOptions execution_options;
  execution_options.device_set = devices_.get();
  execution_options.session_options = &session_opts_;
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
        std::move(graph_def), execution_options, &execution_state_));
  }
  should_delete_worker_sessions_ = true;
  return CreateWorkerSessions(options);
}

3.2.1 Creating the Computation Graph

This builds the GraphExecutionState and constructs the corresponding full Graph from the GraphDef.

GraphDef is the raw graph structure. ConvertGraphDefToGraph performs the format conversion from GraphDef, which contains the graph's metadata, to Graph, which additionally contains the structural information used by the runtime system.
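
A minimal sketch of that conversion step in isolation (the GraphDef is assumed to be populated elsewhere; the header location of ConvertGraphDefToGraph varies between TF versions):

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/graph.h"

tensorflow::Status BuildRuntimeGraph(const tensorflow::GraphDef& graph_def,
                                     tensorflow::Graph* graph) {
  // Convert the serialized GraphDef (graph metadata) into a runtime Graph
  // made of Node/Edge objects that Placer, partitioning and the executors use.
  tensorflow::GraphConstructorOptions opts;
  return tensorflow::ConvertGraphDefToGraph(opts, graph_def, graph);
}

// Usage sketch:
//   tensorflow::Graph graph(tensorflow::OpRegistry::Global());
//   TF_CHECK_OK(BuildRuntimeGraph(graph_def, &graph));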

/* static */ Status GraphExecutionState::MakeForBaseGraph(
    GraphDef&& graph_def, const GraphExecutionStateOptions& options,
    std::unique_ptr<GraphExecutionState>* out_state) {

  auto flib_def = absl::make_unique<FunctionLibraryDefinition>(
      OpRegistry::Global(), graph_def.library());

  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&graph_def, *flib_def, 0));

  if (options.session_options->config.graph_options().place_pruned_graph() ||
      !options.session_options->config.experimental()
           .optimize_for_static_graph()) {
    auto ret = absl::WrapUnique(new GraphExecutionState(
        absl::make_unique<GraphDef>(std::move(graph_def)), std::move(flib_def),
        options));

    // When place_pruned_graph is true, a different Graph* will be initialized
    // each time we prune the original graph, so there is no need to
    // construct a Graph* in this case.
    if (!options.session_options->config.graph_options().place_pruned_graph()) {
      auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
      TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *ret->original_graph_def_,
                                                base_graph.get()));
      TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    }
    *out_state = std::move(ret);
  } else {
    auto ret = absl::WrapUnique(
        new GraphExecutionState(nullptr, std::move(flib_def), options));
    auto base_graph = absl::make_unique<Graph>(OpRegistry::Global());
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToGraph({}, std::move(graph_def), base_graph.get()));
    TF_RETURN_IF_ERROR(ret->InitBaseGraph(std::move(base_graph)));
    *out_state = std::move(ret);
  }
  return Status::OK();
}

InitBaseGraph calls Placer::Run to do operator placement, that is, it places every operator of the computation graph on the most suitable device so as to maximize execution efficiency. The Placer analyzes the Graph and decides how each Node is placed according to the user's requirements along four guidelines:

  • Try to satisfy the user's requirements first; the user can specify placement via device information or colocation.
  • Prefer fast devices whenever possible; every device in the TF system has a priority, and a higher priority means better computing performance.
  • Keep the program runnable; if a Node requests a device that does not exist in the system, an available device is chosen to override the requested placement.
  • Consider proximity; for example, try to keep producers and consumers on the same device to avoid pointless cross-device copies.
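
The sketch below (node name and device strings are made up) shows the user-facing knobs that feed these rules: an explicit device request on a node, plus allow_soft_placement and log_device_placement in the session config:

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/public/session_options.h"

void PlacementKnobsSketch(tensorflow::GraphDef* graph_def,
                          tensorflow::SessionOptions* options) {
  // Rule 1: the user pins an op to a device by setting NodeDef.device.
  tensorflow::NodeDef* node = graph_def->add_node();
  node->set_name("matmul0");                                      // illustrative
  node->set_op("MatMul");
  node->set_device("/job:worker/replica:0/task:0/device:GPU:0");  // requested placement

  // Rule 3: allow_soft_placement lets Placer fall back to an available device
  // instead of failing; log_device_placement prints every final decision.
  options->config.set_allow_soft_placement(true);
  options->config.set_log_device_placement(true);
}
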
Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
  // Save stateful placements before placing.
  RestoreStatefulNodes(new_graph.get());

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_handle = session_handle_;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &new_graph;
  optimization_options.flib_def = flib_def_.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  Placer placer(new_graph.get(), "", flib_def_.get(), device_set_,
                /* default_local_device= */ nullptr,
                session_options_ == nullptr ||
                    session_options_->config.allow_soft_placement(),
                session_options_ != nullptr &&
                    session_options_->config.log_device_placement());
  TF_RETURN_IF_ERROR(placer.Run());

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  for (const Node* n : new_graph->nodes()) {
    node_name_to_cost_id_map_[n->name()] = n->cost_id();
  }
  SaveStatefulNodes(new_graph.get());
  graph_ = new_graph.release();
  return Status::OK();
}

3.2.2 Creating WorkerSessions

After the MasterSession is created successfully, if the cluster is not dynamically configured (the default distributed configuration), the Master does not broadcast to all Workers to dynamically create WorkerSessions. In that case each Worker relies on its SessionMgr instance, which holds a WorkerSession instance named legacy_session_; thus each Worker has a globally unique WorkerSession instance.

Figure 3 creating a WorkerSession

The logic is as follows:

  • First, ReleaseWorker is called to release the existing Workers.
  • Second, GetOrCreateWorker is called to fetch the Worker from the cache again; if it is not there, it is created and cached.
  • Finally, we iterate over the Workers and call CreateWorkerSessionAsync so that each Worker creates a WorkerSession. Each request uses set_session_handle(handle_) to put the MasterSession's session_handle into the request, so every WorkerSession shares the same session_handle as the MasterSession and they all belong to the same MasterSession.

In order to collect the responses returned by all Workers, a BlockingCounter is used to wait: its initial value is set to the number of Workers, and when all of the Workers' CreateWorkerSessionResponse messages have been collected, the counter drops to zero and the BlockingCounter is woken up.
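
BlockingCounter is a small synchronization utility from tensorflow/core/platform; here is a self-contained sketch of the wait pattern used below (the scheduled closures stand in for the CreateWorkerSessionAsync callbacks):

#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/env.h"

void WaitForAllWorkersSketch(int num_workers) {
  tensorflow::BlockingCounter done(num_workers);  // initial value = #workers
  for (int i = 0; i < num_workers; ++i) {
    // In CreateWorkerSessions the callback fires when a worker's
    // CreateWorkerSessionResponse arrives; here it fires immediately.
    tensorflow::Env::Default()->SchedClosure([&done]() { done.DecrementCount(); });
  }
  done.Wait();  // unblocks once the count reaches zero
}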

Status MasterSession::CreateWorkerSessions(
    const WorkerCacheFactoryOptions& options) {
  const std::vector<string> worker_names = filtered_worker_list_;
  WorkerCacheInterface* worker_cache = get_worker_cache();

  struct WorkerGroup {
    // The worker name. (Not owned.)
    const string* name;

    // The worker referenced by name. (Not owned.)
    WorkerInterface* worker = nullptr;

    // Request and responses used for a given worker.
    CreateWorkerSessionRequest request;
    CreateWorkerSessionResponse response;
    Status status = Status::OK();
  };

  BlockingCounter done(worker_names.size());
  std::vector<WorkerGroup> workers(worker_names.size());

  // Release the workers.
  auto cleanup = gtl::MakeCleanup([&workers, worker_cache] {
    for (auto&& worker_group : workers) {
      if (worker_group.worker != nullptr) {
        worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
      }
    }
  });

  string task_name;
  string local_device_name;
  DeviceNameUtils::SplitDeviceName(devices_->client_device()->name(),
                                   &task_name, &local_device_name);
  const int64_t client_device_incarnation =
      devices_->client_device()->attributes().incarnation();

  Status status = Status::OK();
  // Create all the workers & kick off the computations.
  for (size_t i = 0; i < worker_names.size(); ++i) {
    workers[i].name = &worker_names[i];
    workers[i].worker = worker_cache->GetOrCreateWorker(worker_names[i]);
    workers[i].request.set_session_handle(handle_);
    workers[i].request.set_master_task(task_name);
    workers[i].request.set_master_incarnation(client_device_incarnation);
    if (session_opts_.config.share_cluster_devices_in_session() ||
        session_opts_.config.experimental().share_cluster_devices_in_session()) {
      for (const auto& remote_dev : devices_->devices()) {
        *workers[i].request.add_cluster_device_attributes() =
            remote_dev->attributes();
      }
      if (!session_opts_.config.share_cluster_devices_in_session() &&
          session_opts_.config.experimental().share_cluster_devices_in_session()) {
      }
    }

    DeviceNameUtils::ParsedName name;
    if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
      status = errors::Internal("Could not parse name ", worker_names[i]);
      return status;
    }
    if (!name.has_job || !name.has_task) {
      status = errors::Internal("Incomplete worker name ", worker_names[i]);
      return status;
    }

    if (options.cluster_def) {
      *workers[i].request.mutable_server_def()->mutable_cluster() =
          *options.cluster_def;
      workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
      workers[i].request.mutable_server_def()->set_job_name(name.job);
      workers[i].request.mutable_server_def()->set_task_index(name.task);
      // Session state is always isolated when ClusterSpec propagation
      // is in use.
      workers[i].request.set_isolate_session_state(true);
    } else {
      // NOTE(mrry): Do not set any component of the ServerDef,
      // because the worker will use its local configuration.
      workers[i].request.set_isolate_session_state(
          session_opts_.config.isolate_session_state());
    }
    if (session_opts_.config.experimental().share_session_state_in_clusterspec_propagation()) {
      // In a dynamic cluster, the ClusterSpec info is usually propagated by
      // master sessions. However, in data parallel training with multiple
      // masters
      // ("between-graph replication"), we need to disable isolation for
      // different worker sessions to update the same variables in PS tasks.
      workers[i].request.set_isolate_session_state(false);
    }
  }

  for (size_t i = 0; i < worker_names.size(); ++i) {
    auto cb = [i, &workers, &done](const Status& s) {
      workers[i].status = s;
      done.DecrementCount();
    };
    workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
                                                &workers[i].response, cb);
  }

  done.Wait();
  for (size_t i = 0; i < workers.size(); ++i) {
    status.Update(workers[i].status);
  }
  return status;
}
GrpcRemoteWorker

GrpcRemoteWorker is the gRPC client; it invokes the service interfaces of the remote WorkerService through the stub.

void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                              CreateWorkerSessionResponse* response,
                              StatusCallback done) override {
  IssueRequest(request, response, createworkersession_, std::move(done));
}
GrpcWorkerService

On the remote Worker, the message is received by GrpcWorkerService. When a CreateWorkerSessionRequest message arrives, it is handled by the CreateWorkerSessionHandler callback. CreateWorkerSessionHandler is generated by a macro; it schedules a closure on a thread-pool thread, and the closure triggers the CreateWorkerSession method of the Worker (GrpcWorker) to dynamically create a WorkerSession instance.

#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(3) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(CreateWorkerSession, false);

4. WorkerSession

In fact, what GrpcWorker finally calls is the CreateWorkerSession method of WorkerInterface.

Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
                           CreateWorkerSessionResponse* response) {
  return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
}

The CreateWorkerSessionRequest message carries the session_handle assigned by the MasterSession, and GrpcWorker creates a WorkerSession accordingly; session_handle uniquely identifies that WorkerSession within the Worker.

Within the WorkerEnv context of GrpcWorker there is a SessionMgr, which uniformly manages and maintains the life cycle of all WorkerSessions. SessionMgr and WorkerSession have a one-to-many relationship, and each WorkerSession instance is identified by its session_handle.

void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(
      request->session_handle(), request->server_def(),
      request->cluster_device_attributes(), request->isolate_session_state(),
      request->master_task(), request->master_incarnation());
  done(s);
}

4.1 SessionMgr

4.1.1 Definition

The key point is that SessionMgr maintains the correspondence between session_handle and WorkerSession; each WorkerSession is identified by its session_handle.

  • std::map<string, std::shared_ptr<WorkerSession>> sessions_ : maintains the session_handle-to-WorkerSession mapping.

  • std::shared_ptr<WorkerSession> legacy_session_ : the local default WorkerSession instance.

Figure 4 SessionMgr

class SessionMgr {
 public:
  typedef std::function<Status(const ServerDef&, WorkerCacheInterface**)>
      WorkerCacheFactory;

  explicit SessionMgr(
      WorkerEnv* worker_env, const string& default_worker_name,
      std::unique_ptr<WorkerCacheInterface> default_worker_cache,
      WorkerCacheFactory worker_cache_factory);
  ~SessionMgr() {}

  // Allocates state for a new session.
  Status CreateSession(const string& session, const ServerDef& server_def,
                       bool isolate_session_state);
  Status CreateSession(
      const string& session, const ServerDef& server_def,
      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
      bool isolate_session_state);

  // Create WorkerSession from the master with the given `master_task` and
  // `master_incarnation`. We first look for existing WorkerSessions associated
  // with the specified master task. If there are sessions created by the same
  // master but with a different incarnation, it indicates that the remote
  // master has restarted before deleting the sessions on worker. When it
  // happens, old sessions associated with the master will be automatically
  // removed before the new session is created.
  Status CreateSession(
      const string& session, const ServerDef& server_def,
      const protobuf::RepeatedPtrField<DeviceAttributes>& device_attributes,
      bool isolate_session_state, string master_task,
      int64_t master_incarnation);

  void ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache);

  // Updates state (worker cache, devices) of worker session identified by
  // session name (`session`) based on a new server_def and set of devices.
  Status UpdateSession(const string& session, const ServerDef& server_def,
                       const protobuf::RepeatedPtrField<DeviceAttributes>&
                           cluster_device_attributes,
                       bool isolate_session_state);

  // Locates the worker session for a given session handle
  Status WorkerSessionForSession(const string& session_handle,
                                 std::shared_ptr<WorkerSession>* out_session);
  std::shared_ptr<WorkerSession> LegacySession();

  Status DeleteSession(const string& session);

  static string WorkerNameFromServerDef(const ServerDef& server_def);

  void SetLogging(bool active);

  void RetrieveLogs(int64_t step_id, LoggingResponse* response);

  void ClearLogs();

 private:
  WorkerEnv* const worker_env_;  // Not owned.

  // A note about destruction:
  // We must delete graph_mgr before device_mgr, due to shared
  // ownership of OpKernels in the executors. (The graph_mgr will
  // free all stateless OpKernels, and pass over borrowed stateful
  // OpKernels, which are also held in their respective devices'
  // OpSegments.)
  //
  // legacy_session_ owns the worker_env_.device_mgr, and so we must ensure
  // that sessions_'s WorkerSessions are deleted (which do not own the
  // underlying devices, but instead own RenamedDevices) before
  // legacy_session_ is deleted. Further, we must ensure that WorkerSession's
  // device_mgr is deleted after WorkerSession's graph_mgr.

  std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
  std::shared_ptr<WorkerSession> legacy_session_;

  bool is_logging_active_ = false;

  const WorkerCacheFactory worker_cache_factory_;

  Status WorkerSessionForSessionLocked(
      const string& session_handle, std::shared_ptr<WorkerSession>* out_session)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  mutex mu_;
  // A map from session identifier to internal session structure.
  std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);

  // Incarnation and WorkerSession handle associated with a master task.
  struct MasterAssociatedSession {
    const int64_t master_incarnation;
    const string session_handle;
  };
  // A map from master task name to its associated worker sessions.
  std::unordered_multimap<string, MasterAssociatedSession>
      master_to_associated_sessions_ TF_GUARDED_BY(mu_);
};

4.1.2 Creating a Session

The CreateSession method creates the WorkerSession and GraphMgr.

Status SessionMgr::CreateSession(
    const string& session, const ServerDef& server_def,
    const protobuf::RepeatedPtrField<DeviceAttributes>&
        cluster_device_attributes,
    bool isolate_session_state, string master_task,
    int64_t master_incarnation) {
  mutex_lock l(mu_);
  if (session.empty()) {
    return errors::InvalidArgument("Session must be non-empty.");
  }

  // For given master task name, check if one or more `WorkerSession`s have been
  // created previously on this worker, and if so garbage collect the expired
  // `WorkerSession`s. This happens when the master fails before sending
  // `DeleteSession` requests, which can cause `WorkerSession`s to be leaked.
  if (!master_task.empty()) {
    auto it_range = master_to_associated_sessions_.equal_range(master_task);
    if (it_range.first != it_range.second &&
        it_range.first->second.master_incarnation != master_incarnation) {
      auto it = it_range.first;
      while (it != it_range.second) {
        auto session_it = sessions_.find(it->second.session_handle);
        if (session_it != sessions_.end()) {
          sessions_.erase(session_it);
        }
        it = master_to_associated_sessions_.erase(it);
      }
    }
  }

  WorkerCacheInterface* worker_cache = nullptr;
  string worker_name;
  if (server_def.cluster().job().empty()) {
    worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
    worker_name = legacy_session_->worker_name();
  } else {
    TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
    worker_name = WorkerNameFromServerDef(server_def);
  }

  if (worker_cache != nullptr && default_worker_cache_ != nullptr) {
    worker_cache->SetLogging(this->is_logging_active_);
  }

  std::shared_ptr<WorkerSession> worker_session;
  std::vector<std::unique_ptr<Device>> cluster_devices;

  if (isolate_session_state || server_def.cluster().job_size()) {

    // Create a private copy of the DeviceMgr for the WorkerSession.
    std::vector<std::unique_ptr<Device>> renamed_devices;
    for (Device* d : worker_env_->local_devices) {
      renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
          worker_name, d, false, isolate_session_state));
    }
    auto device_mgr = MakeUnique<StaticDeviceMgr>(std::move(renamed_devices));
    LookupLocalDevice cb = [&device_mgr](StringPiece name, Device** device) {
      return device_mgr->LookupDevice(name, device);
    };
    AsRemoteDevices(worker_env_->env, cluster_device_attributes, cb,
                    &cluster_devices);
    std::unique_ptr<DynamicDeviceMgr> remote_devices;
    if (!cluster_device_attributes.empty()) {
      remote_devices = MakeUnique<DynamicDeviceMgr>();
      TF_RETURN_IF_ERROR(
          remote_devices->AddDevices(std::move(cluster_devices)));
    }

    auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
    worker_session.reset(
        new WorkerSession(session, worker_name,
                          std::unique_ptr<WorkerCacheInterface>(worker_cache),
                          std::move(device_mgr), std::move(graph_mgr),
                          std::move(remote_devices)));
  } else {
    AsRemoteDevices(worker_env_->env, cluster_device_attributes, nullptr,
                    &cluster_devices);
    std::unique_ptr<DynamicDeviceMgr> remote_devices;
    if (!cluster_device_attributes.empty()) {
      remote_devices = MakeUnique<DynamicDeviceMgr>();
      TF_RETURN_IF_ERROR(
          remote_devices->AddDevices(std::move(cluster_devices)));
    }
    // Borrow the WorkerEnv's DeviceMgr for the WorkerSession, so
    // that resources using it can use its devices after the
    // WorkerSession has been deleted.
    auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, worker_env_->device_mgr);
    worker_session = WorkerSession::CreateWithBorrowedDeviceMgr(
        session, worker_name,
        std::unique_ptr<WorkerCacheInterface>(worker_cache),
        worker_env_->device_mgr, std::move(graph_mgr),
        std::move(remote_devices));
  }

  sessions_.insert(std::make_pair(session, std::move(worker_session)));
  if (!master_task.empty()) {
    MasterAssociatedSession s{master_incarnation, session};
    master_to_associated_sessions_.emplace(master_task, s);
  }
  return Status::OK();
}

4.1.3 Registering a Graph

Let's use RegisterGraphAsync as an example to look at the Worker's internal functions. As you can see, the basic functionality is accomplished through GraphMgr.

void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }

  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(), session.get(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

4.2 WorkerSession

4.2.1 Definition

Several important member variables of WorkerSession include the management classes GraphMgr, DeviceMgr, DynamicDeviceMgr:

  • string session_name_ : the session name.

  • string worker_name_ : the name of the Worker, e.g. /job:mnist/replica:0/task:1.

  • std::shared_ptr<WorkerCacheInterface> worker_cache_ : the Worker cache.

  • std::unique_ptr<GraphMgr> graph_mgr_ : the computation graphs registered in this session; each Worker can register and run multiple computation graphs, and each graph is identified by a graph_handle.

  • std::unique_ptr<DeviceMgr> device_mgr_ : information about the set of local computing devices.

Figure 5 WorkerSession concept

// WorkerSession encapsulates all of the state relating to a given session.
class WorkerSession {
 public:
  // Collection of local devices. These devices are typically
  // RenamedDevices in all except the SessionMgr.legacy_session_ and
  // sessions created with `isolate_session_state == false`. In the
  // those cases, this method returns a pointer to a borrowed
  // DeviceMgr (typically the `worker_env.device_mgr`).
  DeviceMgr* device_mgr() {
    return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_;
  }

  DynamicDeviceMgr* remote_device_mgr() { return remote_device_mgr_.get(); }

  const string& session_name() const { return session_name_; }
  const string& worker_name() const { return worker_name_; }

  WorkerCacheInterface* worker_cache() const {
    tf_shared_lock l(worker_session_state_mu_);
    return worker_cache_.get();
  }

  GraphMgr* graph_mgr() const { return graph_mgr_.get(); }

  ClusterFunctionLibraryRuntime* cluster_flr() const {
    return cluster_flr_.get();
  }

  WorkerSession(const string& session_name, const string& worker_name,
                std::unique_ptr<WorkerCacheInterface> worker_cache,
                std::unique_ptr<DeviceMgr> device_mgr,
                std::unique_ptr<GraphMgr> graph_mgr,
                std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);

  static std::shared_ptr<WorkerSession> CreateWithBorrowedDeviceMgr(
      const string& session_name, const string& worker_name,
      std::unique_ptr<WorkerCacheInterface> worker_cache,
      DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
      std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);

  // In the eager runtime we allow WorkerSession to be updated, where the
  // worker cache will be recreated. If WorkerSession update is expected and a
  // worker in the cache is used in RPCs, the caller should hold a shared
  // pointer to avoid the workers getting deleted.
  std::shared_ptr<WorkerCacheInterface> GetSharedWorkerCache() {
    tf_shared_lock l(worker_session_state_mu_);
    return worker_cache_;
  }

  // Update an existing worker session with new set of remote workers and
  // devices. Added devices will be owned by the worker session, and removed
  // devices will be freed by their names.
  Status UpdateWorkerCacheAndDevices(
      std::unique_ptr<WorkerCacheInterface> new_worker_cache,
      std::vector<std::unique_ptr<Device>> added_remote_devices,
      const std::vector<Device*>& removed_remote_devices);

  ~WorkerSession();

 private:
  WorkerSession(const string& session_name, const string& worker_name,
                std::unique_ptr<WorkerCacheInterface> worker_cache,
                DeviceMgr* borrowed_device_mgr,
                std::unique_ptr<GraphMgr> graph_mgr,
                std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);

  // The name of the session.
  const string session_name_;

  // The name of the worker. E.g., /job:mnist/replica:0/task:1.
  const string worker_name_;

  mutable mutex worker_session_state_mu_;
  // Object from which WorkerInterface instances can be obtained.
  std::shared_ptr<WorkerCacheInterface> worker_cache_
      TF_GUARDED_BY(worker_session_state_mu_);

  // graph_mgr keeps track of the registered graphs of this session.
  //
  // Note: graph_mgr must be deleted before rendezvous_mgr!
  // Note: graph_mgr must be deleted before device_mgr!
  const std::unique_ptr<GraphMgr> graph_mgr_;

  std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;

  const std::unique_ptr<DeviceMgr> device_mgr_;
  DeviceMgr* const borrowed_device_mgr_;  // Not owned.
  std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_;
};

At this point, the basic Session flow has been sorted out; the individual services will be analyzed in detail in subsequent articles.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

0xFF Reference

Placer, the Placement heuristic algorithm module in TensorFlow