0x00 Abstract
This article uses Horovod on PyTorch as a starting point to analyze how Horovod elastic training recovers after a failure or a change of nodes.
It covers the differences between ElasticSampler and PyTorch's native DistributedSampler, how Horovod elastic training resumes, and related topics.
Links to other articles in this series are as follows:
Deep learning distributed training framework Horovod (1) — basic knowledge
Deep learning distributed training framework Horovod (2) — from the user’s perspective
Deep learning distributed training framework Horovod (3) — What is behind Horovodrun
Deep learning distributed training framework Horovod (4) — Network foundation & Driver
Deep learning distributed training framework Horovod (5) — fusion framework
Deep learning distributed training framework Horovod (6) — background architecture
Deep learning Distributed Training framework Horovod (7) — DistributedOptimizer
Deep learning distributed training framework Horovod (8) — on Spark
Deep learning distributed training framework Horovod (9) — start on Spark
Deep learning distributed training framework Horovod (10) — Run on Spark
Deep learning distributed training framework Horovod (11) — on Spark — GLOO scheme
Deep learning distributed Training framework Horovod (12) — elastic training overall architecture
Deep learning distributed training framework Horovod (13) — Driver of elastic training
Deep learning distributed training framework Horovod (14) — Elastic training discovery node & State
Deep learning distributed training framework Horovod (15) — broadcast & notification
Deep learning Distributed training Framework Horovod (16) — Elastic training Worker life cycle
Deep learning distributed training framework Horovod (17) — Fault tolerance of elastic training
Deep learning distributed training framework Horovod (18) — Kubeflow TF-operator
Deep learning distributed training framework Horovod (19) — Kubeflow Mpi-operator
0x01 Overview
This article was prompted by a reader's question:
In elastic training, if the number of nodes changes, how is the data repartitioned? For example, if new nodes are added before an epoch has finished and the data is repartitioned, is the model already trained on the old partition (still held in memory) still valid?
I happened to have a similar question when analyzing PyTorch's distributed training, so I went back to see how Horovod implements this.
Most of our previous analyses and examples of Horovod used TensorFlow. By now you should have a general idea of the overall logic and of how the various frameworks plug into Horovod, so this article focuses on what is specific to PyTorch.
Another reason to use PyTorch is that the parts of PyTorch involved in resuming training are relatively clear.
The horovod/torch/elastic/ directory contains two files, state.py and sampler.py. Since both relate to elasticity, let's first look at what is special about them.
0x02 Sampler
horovod/torch/elastic/sampler.py defines an ElasticSampler class; let's see what it does specifically to support elasticity.
Because the ElasticSampler docstring says its implementation is very similar to DistributedSampler, PyTorch's native implementation, we'll look at DistributedSampler first.
2.1 PyTorch DistributedSampler
2.1.1 definition
DistributedSampler is defined in torch/utils/data/distributed.py.
DistributedSampler splits the dataset indices into num_replicas parts and assigns one part to each of the num_replicas processes. The parts do not overlap, so no two processes train on the same sample (apart from the few indices repeated as padding to make the split even).
One technical detail of __iter__ is how each worker picks its share of the indices:
```python
indices = indices[self.rank:self.total_size:self.num_replicas]
```
Here num_replicas is the world size (the total number of ranks). The slice starts at self.rank, stops at self.total_size, and advances by num_replicas, so each worker gets exactly the indices that belong to its rank.
Consider this example:
```python
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
print(a[0:15:3])
print(a[1:15:3])
print(a[2:15:3])
```
Output:
```
[1, 4, 7, 10, 13]
[2, 5, 8, 11, 14]
[3, 6, 9, 12, 15]
```
The specific code is as follows:
```python
class DistributedSampler(Sampler[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample: this rank takes every num_replicas-th index, starting at its rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
```
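The padding-plus-stride logic can be checked without PyTorch. Below is a minimal sketch that mimics the unshuffled, drop_last=False path of DistributedSampler.__iter__. The function name strided_partition is hypothetical (it is not a PyTorch API), and it works on 0-based indices rather than the values 1..15 used in the example above.

```python
import math


def strided_partition(n, num_replicas):
    """Split indices 0..n-1 across ranks like DistributedSampler
    does without shuffling and with drop_last=False (assumes n > 0)."""
    num_samples = math.ceil(n / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(n))
    # Pad by cycling from the front until every rank gets num_samples items
    while len(indices) < total_size:
        indices += indices[:total_size - len(indices)]
    # Each rank takes a strided slice starting at its own rank
    return [indices[rank:total_size:num_replicas] for rank in range(num_replicas)]


print(strided_partition(15, 3))  # [[0, 3, 6, 9, 12], [1, 4, 7, 10, 13], [2, 5, 8, 11, 14]]
print(strided_partition(7, 3))   # [[0, 3, 6], [1, 4, 0], [2, 5, 1]]
```

The second call shows the padding at work: 7 samples cannot be split evenly over 3 ranks, so indices 0 and 1 are handed out twice.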
2.1.2 Problems
DistributedSampler has several problems when used for elastic training. Let's analyze them:
- Suppose the user has already trained five batches, so the data of the first five batches has been consumed. If a new worker node is added at this point, training must resume, and the data from those five batches should not be trained on again.
  - Q1: After training resumes, how are the indices of already-processed data removed?
- When nodes are added or removed, the sampler has to be told to change its extraction rule. At the very least num_replicas must be updated and the new value used for extraction. For example, if originally num_replicas = 5 and there are now 6 nodes, num_replicas should become 6.
  - Q2: After training resumes, when is __iter__ called to start the new extraction?
  - Q3: After training resumes, when is num_replicas changed?
Looking at DistributedSampler, its __iter__ keeps no state about what has been processed. If training restarts, it still samples from the full dataset rather than from the remaining data. It offers no answer to the other two questions either.
As a result, it is hard to do elastic training with DistributedSampler, so Horovod uses ElasticSampler to solve these problems.
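The statelessness problem is easy to reproduce in a toy run. In the sketch below, rank_indices is a hypothetical stand-in for DistributedSampler's strided slice. The numbers (12 samples, 4 workers growing to 6) are chosen only to avoid padding. Because nothing records what was consumed, rebuilding the split for the larger world hands every already-trained sample out again.

```python
def rank_indices(n, rank, num_replicas):
    # Stateless strided split, like DistributedSampler without shuffling
    # (n is assumed divisible by num_replicas, so no padding is needed)
    return list(range(n))[rank:n:num_replicas]


n = 12
processed = set()
for rank in range(4):
    # each of the 4 original workers trains its first sample before new workers join
    processed.update(rank_indices(n, rank, 4)[:1])

# Re-creating the split for 6 workers starts again from the *whole* dataset
handed_out = [i for rank in range(6) for i in rank_indices(n, rank, 6)]
repeated = set(handed_out) & processed
print(len(handed_out), sorted(repeated))  # 12 [0, 1, 2, 3]
```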
2.2 ElasticSampler
2.2.1 definition
As you can see from the comments, ElasticSampler claims to be very similar to DistributedSampler. A subsequent code comparison of the two classes shows that the functionality is basically the same.
But there are two new variables worth noting, namely:
```python
self.processed_indices = set()
self.remaining_indices = []
```
The definition is as follows:
```python
import math
import random

import torch.utils.data.distributed

from horovod.torch.mpi_ops import rank, size


class ElasticSampler(torch.utils.data.Sampler):
    """Sampler that partitions dataset across ranks and repartitions after reset events.

    Works similar to `DistributedSampler`, but with an optional capability to record
    which dataset indices have been processed each batch. When tracked by a `TorchState`
    object, the sampler will automatically repartition the unprocessed indices among the
    new set of workers.

    In order to use this object successfully it is recommended that the user:

    1. Include this object in the `TorchState`.
    2. Call `record_batch` or `record_indices` after processing a set of samples.
    3. Call `set_epoch` at the end of each epoch to clear the processed indices.

    Args:
        dataset: Dataset used for sampling (assumed to be of constant size).
        shuffle: If `True` (default), shuffle the indices.
        seed: Random seed used to shuffle the sampler when `shuffle=True`.
              This number should be identical across all ranks (default: 0).
    """
    def __init__(self, dataset, shuffle=True, seed=0):
        self.dataset = dataset
        self.shuffle = shuffle
        self.seed = seed

        self.epoch = 0
        self.processed_indices = set()  # indices already trained on in this epoch

        self.num_replicas = 0
        self.rank = 0
        self.remaining_indices = []  # indices not yet trained on in this epoch
        self.num_samples = 0
        self.total_size = 0

        self.reset()
```
2.2.2 Elastic Solution
The elastic scheme revolves around these two variables.
2.2.2.1 General process
Recall the usage that its docstring recommends:
1. Include this object in the `TorchState`.
2. Call `record_batch` or `record_indices` after processing a set of samples.
3. Call `set_epoch` at the end of each epoch to clear the processed indices.
From this we can derive the internal logic:
- Train the epoch.
- When __iter__ is called to fetch data, it executes self.indices = self.remaining_indices[:], so data is drawn only from samples not yet trained on.
- After each batch is processed, the user calls record_batch or record_indices to add that batch's indices to processed_indices, which thus records the data already trained on.
- If something goes wrong, or the set of nodes changes:
  - reset is called. It removes the trained indices (processed_indices) from the full index range, leaving the untrained indices in self.remaining_indices.
  - Training resumes and draws only from the untrained data.
- When the epoch completes, set_epoch is called to clear processed_indices, and it also calls reset to recompute the partition.
The specific function code is:
```python
    def set_epoch(self, epoch):
        """Sets the epoch for this sampler.

        When `shuffle=True`, this ensures all replicas use a different random ordering
        for each epoch.

        Will clear and reset the `processed_indices` for the next epoch. It is important
        that this is called at the end of the epoch (not the beginning) to ensure that
        partially completed epochs do not reprocess samples.

        Args:
            epoch: Epoch number.
        """
        self.epoch = epoch
        self.processed_indices = set()  # clear the record of trained indices
        self.reset()  # recompute the partition for the new epoch

    def record_batch(self, batch_idx, batch_size):
        """Record indices at batch `batch_idx` with length `batch_size` as processed."""
        indices = set(self.get_indices(batch_idx, batch_size))
        self.record_indices(indices)

    def record_indices(self, indices):
        """Record set `indices` as processed."""
        self.processed_indices.update(indices)  # remember which samples were trained

    def get_indices(self, batch_idx, batch_size):
        """Return list of indices at batch `batch_idx` with length `batch_size`."""
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(self.indices))
        return self.indices[start_idx:end_idx]

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        self.processed_indices = state_dict['processed_indices']
        self.reset()  # repartition using the restored state

    def state_dict(self):
        return dict(  # this is what gets saved for the state
            epoch=self.epoch,
            processed_indices=self.processed_indices
        )

    def reset(self):
        # size() is defined in horovod/torch/mpi_ops.py (size = _basics.size), i.e. hvd.size()
        self.num_replicas = size()  # re-read the number of workers
        self.rank = rank()

        # Exclude any samples we have already processed this epoch
        self.remaining_indices = [idx for idx in range(len(self.dataset))
                                  if idx not in self.processed_indices]

        self.num_samples = int(math.ceil(len(self.remaining_indices) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        self.indices = self.remaining_indices[:]  # draw only from the remaining data
        if self.shuffle:
            # Shuffle indices across workers deterministically in place
            seed = self.seed + self.epoch
            random.Random(seed).shuffle(self.indices)

        # add extra samples to make it evenly divisible
        self.indices += self.indices[:(self.total_size - len(self.indices))]
        assert len(self.indices) == self.total_size

        # subsample: start at self.rank, stop at total_size, step by num_replicas
        self.indices = self.indices[self.rank:self.total_size:self.num_replicas]
        assert len(self.indices) == self.num_samples

        return iter(self.indices)

    def __len__(self):
        return self.num_samples
```
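The bookkeeping above can be modeled in a few lines of plain Python. This is a sketch, not Horovod's class. MiniElasticSampler is a hypothetical name, and it takes rank and num_replicas as arguments instead of reading hvd.rank()/hvd.size(). The cross-worker sync is simulated by recording the same processed set on every sampler before calling reset.

```python
import math
import random


class MiniElasticSampler:
    """Toy, framework-free model of ElasticSampler's bookkeeping.

    Hypothetical class: rank and num_replicas are passed in explicitly
    instead of being read from hvd.rank()/hvd.size().
    """

    def __init__(self, dataset_len, rank, num_replicas, shuffle=False, seed=0):
        self.dataset_len = dataset_len
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        self.processed_indices = set()
        self.reset(rank, num_replicas)

    def reset(self, rank, num_replicas):
        self.rank, self.num_replicas = rank, num_replicas
        # Keep only the samples not yet trained on in this epoch
        self.remaining_indices = [i for i in range(self.dataset_len)
                                  if i not in self.processed_indices]
        self.num_samples = math.ceil(len(self.remaining_indices) / num_replicas)
        self.total_size = self.num_samples * num_replicas

    def __iter__(self):
        indices = self.remaining_indices[:]
        if self.shuffle:
            random.Random(self.seed + self.epoch).shuffle(indices)
        # Pad from the front so the split is even (same approach as ElasticSampler)
        indices += indices[:self.total_size - len(indices)]
        return iter(indices[self.rank:self.total_size:self.num_replicas])

    def record_indices(self, indices):
        self.processed_indices.update(indices)

    def set_epoch(self, epoch):
        self.epoch = epoch
        self.processed_indices = set()
        self.reset(self.rank, self.num_replicas)


# Epoch 0 on 2 workers: each rank trains its first 3 samples
workers = [MiniElasticSampler(10, r, 2) for r in range(2)]
trained = set()
for s in workers:
    trained.update(list(s)[:3])
print(sorted(trained))  # [0, 1, 2, 3, 4, 5]

# A third worker joins: every rank learns the global processed set, then resets
workers = [MiniElasticSampler(10, r, 3) for r in range(3)]
for s in workers:
    s.record_indices(trained)
    s.reset(s.rank, 3)
print([list(s) for s in workers])  # [[6, 9], [7, 6], [8, 7]]
```

After the resize only samples 6 to 9 remain, and they are spread over three ranks (with 6 and 7 repeated as padding). Calling set_epoch afterwards empties processed_indices and makes all 10 samples available again.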
2.2.2.2 Exception Handling
In horovod/torch/elastic/state.py, ElasticSampler's load_state_dict method is called when training is resumed.
load_state_dict calls reset, which removes the trained data and leaves only the untrained data.
So after a restart, data already trained on within the current epoch is not trained on again.
We’ll examine this process in more detail later.
2.2.3 How to use it
ElasticSampler is used as follows; the code is in examples/elastic/pytorch/pytorch_imagenet_resnet50_elastic.py.
This section focuses on normal usage. Exception handling is covered later, and some minor code is omitted here.
2.2.3.1 Main code
The main point to note in the main code is that two elastic samplers are created with ElasticSampler, one for training and one for validation.
```python
if __name__ == '__main__':
    allreduce_batch_size = args.batch_size * args.batches_per_allreduce

    # Elastic Horovod: use ElasticSampler to partition data among workers.
    train_dataset = datasets.ImageFolder()
    train_sampler = hvd.elastic.ElasticSampler(train_dataset)  # elastic sampler for training
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=allreduce_batch_size,
        sampler=train_sampler,
        **kwargs)

    val_dataset = datasets.ImageFolder()
    val_sampler = hvd.elastic.ElasticSampler(val_dataset)  # elastic sampler for validation
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.val_batch_size,
        sampler=val_sampler,
        **kwargs)

    # Set up standard ResNet-50 model.
    model = models.resnet50()

    # Horovod: scale learning rate by the number of GPUs.
    optimizer = optim.SGD(model.parameters(),
                          lr=(args.base_lr * lr_scaler),
                          momentum=args.momentum, weight_decay=args.wd)

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters(),
        compression=compression,
        backward_passes_per_step=args.batches_per_allreduce,
        op=hvd.Adasum if args.use_adasum else hvd.Average,
        gradient_predivide_factor=args.gradient_predivide_factor)

    # Restore from a previous checkpoint, if initial_epoch is specified.
    # Horovod: restore on the first worker which will broadcast weights to other workers.
    state = hvd.elastic.TorchState(model=model,
                                   optimizer=optimizer,
                                   train_sampler=train_sampler,
                                   val_sampler=val_sampler,
                                   epoch=resume_from_epoch,
                                   batch=0)

    full_train(state)
```
2.2.3.2 Training code
The following is the actual training code.
```python
def train(state):
    model.train()
    epoch = state.epoch
    batch_offset = state.batch

    with tqdm(total=len(train_loader),
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        # idx is the batch index within the current pass over the loader
        for idx, (data, target) in enumerate(train_loader):
            # Elastic Horovod: update the current batch index this epoch
            # and commit / check for host updates. Do not check hosts when
            # we commit as it would be redundant.
            state.batch = batch_idx = batch_offset + idx
            if args.batches_per_commit > 0 and \
                    state.batch % args.batches_per_commit == 0:
                state.commit()
            elif args.batches_per_host_check > 0 and \
                    state.batch % args.batches_per_host_check == 0:
                state.check_host_updates()

            adjust_learning_rate(epoch, batch_idx)

            optimizer.zero_grad()
            # Split data into sub-batches of size batch_size
            for i in range(0, len(data), args.batch_size):
                data_batch = data[i:i + args.batch_size]
                target_batch = target[i:i + args.batch_size]
                output = model(data_batch)
                train_accuracy.update(accuracy(output, target_batch))
                loss = F.cross_entropy(output, target_batch)
                train_loss.update(loss)
                # Average gradients among sub-batches
                loss.div_(math.ceil(float(len(data)) / args.batch_size))
                loss.backward()

            # Elastic Horovod: record which samples were processed this batch
            # so we do not reprocess them if a reset event occurs
            state.train_sampler.record_batch(idx, allreduce_batch_size)

            # Gradient is applied across all ranks
            optimizer.step()

    state.commit()


def end_epoch(state):
    state.epoch += 1
    state.batch = 0
    state.train_sampler.set_epoch(state.epoch)  # clear processed indices for the new epoch


@hvd.elastic.run
def full_train(state):
    while state.epoch < args.epochs:
        train(state)
        validate(state.epoch)
        save_checkpoint(state.epoch)
        end_epoch(state)
```
The logic of one epoch (normal path) is as follows:
- On the first run, reset is called for initialization. It builds an index list from the length of the dataset and subtracts processed_indices from it to get the indices this epoch still has to process, which are stored in remaining_indices.
- __iter__ executes self.indices = self.remaining_indices[:], so self.indices can be used for iterative extraction.
- The training function draws data through iter(indices), then calls record_batch, which calls record_indices to add the indices just used to processed_indices. processed_indices therefore records every index consumed so far.
- After the epoch ends, set_epoch is called to reset processed_indices, and it calls reset to recompute remaining_indices.
[Figure: ElasticSampler within one epoch. 1) reset: remaining_indices = dataset - processed_indices; 2) __iter__: indices = remaining_indices[:]; 3) train() loop: iter(indices), backward(), step(), then record_indices: processed_indices.update(indices); 4) set_epoch at the end of the epoch, which calls reset again.]
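record_batch relies on get_indices to turn a batch number back into sample indices. The sketch below reproduces that slicing on a hypothetical per-rank index list. The final batch is shorter because end_idx is clamped to the list length. Because __iter__ rebuilds self.indices from the remaining indices after every reset, batch numbers restart at 0. This appears to be why the training loop passes the loop-local idx, not batch_offset + idx, to record_batch.

```python
def get_indices(rank_indices, batch_idx, batch_size):
    """Same slicing as ElasticSampler.get_indices, on an explicit list."""
    start_idx = batch_idx * batch_size
    end_idx = min(start_idx + batch_size, len(rank_indices))
    return rank_indices[start_idx:end_idx]


rank_indices = [1, 4, 7, 10, 13]  # hypothetical share of one rank after subsampling
processed = set()
for batch_idx in range(3):        # three batches of size 2; the last one is short
    processed.update(get_indices(rank_indices, batch_idx, 2))
print(get_indices(rank_indices, 2, 2), sorted(processed))  # [13] [1, 4, 7, 10, 13]
```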
0x03 Save and check periodically
3.1 Periodic Storage
Horovod recommends that users periodically call state.commit() to back up the state in memory.
- Regular backups are very useful. When some workers fail unexpectedly, a recent backup makes it possible to restore training even if the live state has been corrupted. For example, if a worker fails in the middle of updating parameters, some gradients may have been applied only partially. This cannot be undone, and training cannot simply continue. So when this happens a HorovodInternalError is raised, and when hvd.elastic.run catches it, all state is restored from the most recent commit.
- Because committing state is expensive (for example, copying a model with many parameters takes a long time), you need to balance the per-batch processing time against how far back training has to roll back if an error occurs. For example, committing every 10 batches cuts the copying cost by a factor of 10, but when an error occurs you may have to roll back up to 10 batches.
- Elastic Horovod can avoid these rollbacks through what it calls graceful removal of workers. If the driver process finds that a host has become available or has been marked for removal, it pushes a notification to all workers. The next time state.commit() or the more lightweight state.check_host_updates() is called, a HostsUpdatedInterrupt exception is raised. This exception is handled like HorovodInternalError, except that the parameter state is not restored to the last commit; training continues from the current live parameters.
- In general, if your hardware is reliable and stable and your orchestration system gives sufficient warning before task nodes are removed, you can call state.commit() infrequently. In that case, call only the cheaper state.check_host_updates() at the end of each batch to detect node changes.
The specific code is as follows:
```python
@hvd.elastic.run
def train(state):
    for state.epoch in range(state.epoch, epochs):
        for state.batch in range(state.batch, batches_per_epoch):
            data, target = get_random_batch()
            train_one_batch(data, target)
            if state.batch % batches_per_commit == 0:
                state.commit()  # periodically save state
        state.batch = 0
```
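The commit-frequency trade-off is simple arithmetic. The sketch below uses hypothetical helper names and assumes that commit() runs just before batch b whenever b % batches_per_commit == 0, as in the ImageNet training loop earlier. It counts commits per epoch and how many batches are redone when a HorovodInternalError rolls training back.

```python
import math


def commits_per_epoch(batches_per_epoch, batches_per_commit):
    # commit() fires before every batch b with b % batches_per_commit == 0
    return math.ceil(batches_per_epoch / batches_per_commit)


def batches_to_redo(failed_batch, batches_per_commit):
    # restore() returns to the commit taken before batch
    # (failed_batch // k) * k, so that batch through the failed one are redone
    return failed_batch % batches_per_commit + 1


for k in (1, 10):
    print(k, commits_per_epoch(1000, k), batches_to_redo(509, k))
# 1 1000 1
# 10 100 10
```

Committing ten times less often means ten times fewer state copies, at the cost of redoing up to ten batches after a failure.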
3.2 Exception Handling
The key difference between HorovodInternalError and HostsUpdatedInterrupt is:
- HorovodInternalError: when hvd.elastic.run catches this exception, all state is restored from the latest commit.
- HostsUpdatedInterrupt: handled in the same way as HorovodInternalError, except that the parameter state is not restored to the last commit; it is kept from the current live parameters.
I emphasize this because the two cases are recovered differently, as we will see later.
3.3 Commit
Calling state.commit() does two things: it saves the state, and it calls check_host_updates to check for host changes.
```python
class State(object):
    """State representation used for tracking in memory state across workers."""

    def commit(self):
        self.save()
        self.check_host_updates()
```
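Here self.save() dispatches to the subclass's save; in the context of this article that is TorchState.save, described below.
In addition, check_host_updates raises HostsUpdatedInterrupt once a host-change notification has arrived. The exception's argument, all_update == HostUpdateResult.removed, becomes its skip_sync flag. If hosts were only removed, synchronization is skipped, presumably because no new worker needs to receive the state. If any host was added, the state must be synchronized. HostUpdateResult is a bit flag and HostUpdateResult.removed is 1, so this single comparison covers every case, which is a rather neat design.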
```python
class HostUpdateResult(IntFlag):
    no_update = 0
    removed = 1
    added = 2
    mixed = removed | added


def check_host_updates(self):
    """Checks that a notification has been sent indicating that hosts can be added or will be removed.

    Raises a `HostsUpdatedInterrupt` if such a notification has been received.
    """
    # Iterate through the update messages sent from the server. If the update timestamp
    # is greater than the last update timestamp, then trigger a HostsUpdatedException.
    last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
    all_update = HostUpdateResult.no_update
    while not self._host_messages.empty():
        timestamp, update = self._host_messages.get()
        if timestamp > last_updated_timestamp:
            last_updated_timestamp = timestamp
            all_update |= update

    # In order to ensure all workers raise the exception at the same time, we need to sync
    # the updated state across all the workers.
    # TODO(travis): this should be a max allreduce to account for changes in rank 0
    prev_timestamp, self._last_updated_timestamp, all_update = \
        self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))

    # At this point, updated state is globally consistent across all ranks.
    if self._last_updated_timestamp > prev_timestamp:
        # skip_sync is True only when hosts were removed and none were added
        raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
```
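The skip_sync decision can be exercised in isolation. This sketch redefines HostUpdateResult with the same values as in the code above and ORs pending updates the way the while loop in check_host_updates does. skip_sync_for is a hypothetical helper, and the broadcast step is left out.

```python
from enum import IntFlag


class HostUpdateResult(IntFlag):
    no_update = 0
    removed = 1
    added = 2
    mixed = removed | added


def skip_sync_for(updates):
    """OR all pending updates together, then decide skip_sync like check_host_updates."""
    all_update = HostUpdateResult.no_update
    for update in updates:
        all_update |= update
    return all_update == HostUpdateResult.removed


print(skip_sync_for([HostUpdateResult.removed]))                          # True
print(skip_sync_for([HostUpdateResult.added]))                            # False
print(skip_sync_for([HostUpdateResult.removed, HostUpdateResult.added]))  # False
```

A removal followed by an addition ORs to mixed (3), which is not equal to removed, so the state is synchronized.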
0x04 State
Next comes the exception handling logic, which revolves around State. First, let's recall how State behaves when training resumes.
4.1 Recovery Training
Two exceptions can trigger a restart of training:
- If a collective operation such as ring allreduce fails, HorovodInternalError(e) is raised.
- If the driver, through the host discovery script, finds hosts that were added or marked for removal, HostsUpdatedInterrupt is raised.
They are then handled as follows:
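```python
def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()  # synchronize state across workers before (re)starting

                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore()  # roll back to the last commit
                    skip_sync = False  # restored state must be re-synchronized
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync  # record whether synchronization is required

                reset()  # restart Horovod with the new set of workers
                state.on_reset()  # notify the state of the reset (calls state.reset())
        finally:
            notification_manager.remove_listener(state)
    return wrapper
```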
The logic is as follows:
[Figure: run_fn inside a worker. A while True loop calls state.sync(), then train (optimizer.apply_gradients, state.commit()). HorovodInternalError leads to state.restore(); HostsUpdatedInterrupt skips it. Both paths then call reset() and state.on_reset() and return to the top of the loop.]
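To see the order in which the hooks fire, here is a self-contained simulation of the loop. HorovodInternalError, HostsUpdatedInterrupt, RecordingState and the scripted failures are local stand-ins, not Horovod's classes. The notification manager is omitted, but the control flow matches run_fn above.

```python
import functools


class HorovodInternalError(RuntimeError):
    """Stand-in for Horovod's exception of the same name."""


class HostsUpdatedInterrupt(RuntimeError):
    """Stand-in that carries skip_sync like Horovod's exception."""

    def __init__(self, skip_sync):
        super().__init__()
        self.skip_sync = skip_sync


class RecordingState:
    """Hypothetical state object that only logs which hooks were called."""

    def __init__(self):
        self.calls = []

    def sync(self):
        self.calls.append('sync')

    def restore(self):
        self.calls.append('restore')

    def on_reset(self):
        self.calls.append('on_reset')


def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        skip_sync = False
        while True:
            if not skip_sync:
                state.sync()
            try:
                return func(state, *args, **kwargs)
            except HorovodInternalError:
                state.restore()  # roll back to the last commit
                skip_sync = False
            except HostsUpdatedInterrupt as e:
                skip_sync = e.skip_sync  # keep the live state
            reset()
            state.on_reset()
    return wrapper


# Scripted failures: an internal error, then a host removal, then success
failures = [HorovodInternalError(), HostsUpdatedInterrupt(skip_sync=True)]


def train(state):
    if failures:
        raise failures.pop(0)
    return 'done'


state = RecordingState()
result = run_fn(train, reset=lambda: state.calls.append('reset'))(state)
print(result, state.calls)
# done ['sync', 'restore', 'reset', 'on_reset', 'sync', 'reset', 'on_reset']
```

Only the HorovodInternalError path calls restore(). After the removal-only interrupt the next pass skips sync().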
Since there are a lot of state operations involved here, let’s look at TorchState:
4.2 TorchState
First, let's look at how TorchState is used. A TorchState is created as follows:
```python
state = hvd.elastic.TorchState(model, optimizer, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset])  # register the user-defined callback on_state_reset
train(state)
```
Next, let's look at the definition of TorchState. Its sync and restore methods, together with the on_reset it inherits from State, are what get called when training resumes.
__init__ sets up handlers. In our example, train_sampler and val_sampler each get their own handler, a SamplerStateHandler.
TorchState inherits from ObjectState, which inherits from State. So the self.save() in the commit code calls TorchState.save, which in turn calls SamplerStateHandler.save.
```python
class TorchState(ObjectState):
    """State representation of a PyTorch training process.

    Multiple models and optimizers are supported by providing them as
    kwargs. During initialization, `TorchState` will assign attributes
    for every keyword argument, and handle its state synchronization.

    Args:
        model: Optional PyTorch model.
        optimizer: Optional PyTorch optimizer.
        kwargs: Attributes sync, will be exposed as attributes of the object. If a handler exists
                for the attribute type, it will be used to sync the object, otherwise it will be
                handled an ordinary Python object.
    """
    def __init__(self, model=None, optimizer=None, **kwargs):
        kwargs.update(dict(model=model, optimizer=optimizer))
        # build a handler (e.g. SamplerStateHandler) for every attribute whose type is registered
        self._handlers, kwargs = _get_handlers(kwargs)
        for name, handler in self._handlers.items():
            setattr(self, name, handler.value)
        super(TorchState, self).__init__(bcast_object=broadcast_object,
                                         get_rank=rank,
                                         **kwargs)

    def save(self):
        for handler in self._handlers.values():
            handler.save()  # e.g. SamplerStateHandler.save
        super(TorchState, self).save()

    def restore(self):
        for handler in self._handlers.values():
            handler.restore()  # e.g. SamplerStateHandler.restore
        super(TorchState, self).restore()

    def sync(self):
        for handler in self._handlers.values():
            handler.sync()  # e.g. SamplerStateHandler.sync
        super(TorchState, self).sync()

    def __setattr__(self, name, value):
        if hasattr(self, name) and name in self._handlers:
            self._handlers[name].set_value(value)
        super().__setattr__(name, value)
```
The base class State contains:
```python
class State(object):
    def on_reset(self):
        self._host_messages = queue.Queue()  # drop host notifications received before the reset
        self.reset()
```
4.3 Setting up handlers
In the previous section we saw that save, restore and sync all go through self._handlers, so we need to analyze the handlers further.
First, how handlers are set up. The code below uses a global registry, _handler_registry, that specifies which handler class manages instances of which type. For example, the entry (ElasticSampler, SamplerStateHandler) means that ElasticSampler instances are handled by SamplerStateHandler.
```python
_handler_registry = [
    (torch.nn.Module, ModelStateHandler),
    (torch.optim.Optimizer, OptimizerStateHandler),
    (ElasticSampler, SamplerStateHandler),  # ElasticSampler -> SamplerStateHandler
]


def get_handler_registry():
    return _handler_registry


def set_handler_registry(registry):
    global _handler_registry
    _handler_registry = registry


def _get_handler(v):
    # For v = train_sampler or val_sampler (both ElasticSampler instances),
    # a SamplerStateHandler is built and returned
    for handler_type, handler_cls in _handler_registry:
        if isinstance(v, handler_type):
            return handler_cls(v)  # e.g. SamplerStateHandler(train_sampler)
    return None


def _get_handlers(kwargs):
    handlers = {}
    remainder = {}
    # e.g. k, v = 'train_sampler', train_sampler
    for k, v in kwargs.items():
        handler = _get_handler(v)
        if handler:
            handlers[k] = handler
        else:
            remainder[k] = v
    return handlers, remainder
```
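The split between handled and plain attributes can be reproduced with stand-in classes. In this sketch ElasticSampler and SamplerStateHandler are empty placeholders, not the Horovod classes. The point is that an attribute whose type is registered gets a handler, while everything else (epoch, batch, ...) is returned in remainder, which TorchState.__init__ passes on to ObjectState.

```python
class ElasticSampler:
    """Placeholder for horovod.torch.elastic.ElasticSampler."""


class SamplerStateHandler:
    """Placeholder handler that just wraps the value."""

    def __init__(self, value):
        self.value = value


_handler_registry = [(ElasticSampler, SamplerStateHandler)]


def _get_handler(v):
    for handler_type, handler_cls in _handler_registry:
        if isinstance(v, handler_type):
            return handler_cls(v)
    return None


def _get_handlers(kwargs):
    handlers, remainder = {}, {}
    for k, v in kwargs.items():
        handler = _get_handler(v)
        if handler:
            handlers[k] = handler   # managed by a state handler
        else:
            remainder[k] = v        # left to ObjectState as a plain attribute
    return handlers, remainder


handlers, remainder = _get_handlers(dict(train_sampler=ElasticSampler(), epoch=0, batch=0))
print(list(handlers), remainder)  # ['train_sampler'] {'epoch': 0, 'batch': 0}
```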
4.4 SamplerStateHandler
Now that we know ElasticSampler is handled by SamplerStateHandler, let's examine SamplerStateHandler.
After initialization, self.value is the ElasticSampler analyzed earlier.
The ElasticSampler's state is saved during initialization and is used to restore the ElasticSampler if an error occurs.
save is also called later, on each commit, to refresh this saved state, as we'll see shortly.
```python
class SamplerStateHandler(StateHandler):
    def __init__(self, sampler):
        super().__init__(sampler)
        # snapshot the ElasticSampler's state (epoch, processed_indices)
        self._saved_sampler_state = copy.deepcopy(self.value.state_dict())

    def save(self):
        # refresh the snapshot on every commit
        self._saved_sampler_state = copy.deepcopy(self.value.state_dict())

    def restore(self):
        # roll the sampler back to the last snapshot; load_state_dict() calls reset()
        self.value.load_state_dict(self._saved_sampler_state)

    def sync(self):
        # 1) Get the set of processed indices from all workers
        world_processed_indices = _union(allgather_object(self.value.processed_indices))

        # 2) Replace local processed indices with global indices
        state_dict = self.value.state_dict()
        state_dict['processed_indices'] = world_processed_indices

        # 3) Broadcast and load the state to make sure we're all in sync
        self.value.load_state_dict(broadcast_object(state_dict))
```
The base class of SamplerStateHandler is:
```python
class StateHandler(object):
    def __init__(self, value):
        self.value = value

    def save(self):
        raise NotImplementedError()

    def restore(self):
        raise NotImplementedError()

    def sync(self):
        raise NotImplementedError()

    def set_value(self, value):
        self.value = value
        self.save()
```
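sync matters because record_batch only records a rank's own indices. The following sketch simulates steps 1) to 3) for two workers whose processed sets are merged before a third worker joins. remaining_after_sync is a hypothetical helper. The allgather is replaced by a plain list of sets, and the broadcast is omitted because every simulated worker already computes the same result.

```python
import math


def remaining_after_sync(dataset_len, per_worker_processed, new_world_size):
    # 1) allgather + union: every worker learns every processed index
    world_processed = set().union(*per_worker_processed)
    # 2)/3) the merged state is loaded everywhere; load_state_dict() then calls reset(),
    #       which recomputes the remaining indices for the new world size
    remaining = [i for i in range(dataset_len) if i not in world_processed]
    num_samples = math.ceil(len(remaining) / new_world_size)
    return world_processed, remaining, num_samples


# Two workers recorded only their own batches; then a third worker joins
world, remaining, per_rank = remaining_after_sync(10, [{0, 2}, {1, 3}], 3)
print(sorted(world), remaining, per_rank)  # [0, 1, 2, 3] [4, 5, 6, 7, 8, 9] 2
```

Without the union, the first worker would believe that samples 1 and 3 were still untrained.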
4.5 save
Let's trace the sequence of save calls.
TorchState inherits from ObjectState, which inherits from State, so:
- The self.save() in the commit code mentioned earlier calls TorchState.save.
- TorchState.save calls SamplerStateHandler.save.
- SamplerStateHandler.save saves ElasticSampler's epoch and processed_indices.
In this way, the model's state and ElasticSampler's state are saved at every periodic commit and are used when training resumes, as shown below:
[Figure: TorchState.commit calls 1) save, which calls 2) SamplerStateHandler.save: _saved_sampler_state = copy.deepcopy(value.state_dict()), which reads 3) ElasticSampler.state_dict: dict(epoch, processed_indices).]
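The deepcopy in SamplerStateHandler.save is not cosmetic. ElasticSampler.state_dict returns the live processed_indices set, which record_indices keeps updating in place. A small sketch with a hypothetical stand-in class (TinySampler) shows the difference between a deep and a shallow snapshot.

```python
import copy


class TinySampler:
    """Hypothetical stand-in exposing the same state_dict() shape as ElasticSampler."""

    def __init__(self):
        self.epoch = 0
        self.processed_indices = set()

    def state_dict(self):
        # returns the live set object, not a copy
        return dict(epoch=self.epoch, processed_indices=self.processed_indices)


sampler = TinySampler()
sampler.processed_indices.update({0, 1})      # batches trained before the commit
deep = copy.deepcopy(sampler.state_dict())    # what SamplerStateHandler.save stores
shallow = sampler.state_dict()                # aliasing snapshot, for comparison
sampler.processed_indices.update({2, 3})      # batches trained after the commit
print(sorted(deep['processed_indices']), sorted(shallow['processed_indices']))
# [0, 1] [0, 1, 2, 3]
```

Only the deep copy still reflects the state at commit time, which is what restore() needs after a HorovodInternalError.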
The static definitions alone are still hard to follow, so we need to analyze the dynamic process. There are two kinds of exceptions, and we analyze each one separately.
Recall the biggest difference between the two exceptions:
- HorovodInternalError: when hvd.elastic.run catches this exception, all state is restored from the latest commit.
- HostsUpdatedInterrupt: handled the same way as HorovodInternalError, except that the parameter state is not rolled back to the last commit. Training continues from the current live parameters.
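The contrast between the two exceptions can be reproduced with a toy version of the retry loop. This is a minimal sketch, not Horovod itself. The two exception classes are hypothetical stand-ins for Horovod's own. run_fn keeps the control flow of the run_fn shown later in this article but drops the notification manager. ToyState is hypothetical, with `value` as live progress and `committed` as the value at the last commit.

```python
import functools


class HorovodInternalError(RuntimeError):
    """Hypothetical stand-in for Horovod's HorovodInternalError."""


class HostsUpdatedInterrupt(RuntimeError):
    """Hypothetical stand-in for Horovod's HostsUpdatedInterrupt."""

    def __init__(self, skip_sync=False):
        super().__init__('hosts updated')
        self.skip_sync = skip_sync


def run_fn(func, reset):
    """Same retry loop as Horovod's run_fn, minus the notification manager."""
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        skip_sync = False
        while True:
            if not skip_sync:
                state.sync()
            try:
                return func(state, *args, **kwargs)
            except HorovodInternalError:
                state.restore()          # roll back to the last commit
                skip_sync = False        # a rollback is always followed by a sync
            except HostsUpdatedInterrupt as e:
                skip_sync = e.skip_sync  # keep the live state as-is
            reset()
            state.on_reset()
    return wrapper


class ToyState:
    """Hypothetical state: `value` is live progress, `committed` the last commit."""

    def __init__(self):
        self.value = 0
        self.committed = 0
        self.log = []

    def restore(self):
        self.value = self.committed
        self.log.append('restore')

    def sync(self):
        self.log.append('sync')

    def on_reset(self):
        self.log.append('on_reset')


pending = [HostsUpdatedInterrupt(), HorovodInternalError()]


def _train(state):
    state.value += 1  # progress made since the last commit
    if pending:
        raise pending.pop(0)
    return state.value


train = run_fn(_train, reset=lambda: None)
state = ToyState()
result = train(state)
print(result, state.log)
# 1 ['sync', 'on_reset', 'sync', 'restore', 'on_reset', 'sync']
```

After the HostsUpdatedInterrupt the live value survives, so it reaches 2 on the next attempt. After the HorovodInternalError it is rolled back to the committed 0 before training resumes.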
4.6 HostsUpdatedInterrupt
HostsUpdatedInterrupt is raised when the driver, through the node discovery script, finds a node marked as added or removed. This is not a critical exception, so training of the current epoch can continue. The samples this epoch has already processed are simply removed from the remaining training data. Therefore the parameters are not rolled back to the last commit, and training continues from their current live values.
In the following code, we keep only the code relevant to HostsUpdatedInterrupt.
```python
import functools

# notification_manager and HostsUpdatedInterrupt are defined inside Horovod


def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()  # 3) sync state (including the sampler) across workers

                try:
                    return func(state, *args, **kwargs)  # 4) (re)enter training
                except HostsUpdatedInterrupt as e:  # 1) exception handling
                    skip_sync = e.skip_sync  # 1.1) record whether a sync is needed

                reset()  # 2) calls _basics.init to reinitialize Horovod
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper
```
After an exception occurs:
- 1) HostsUpdatedInterrupt means that training of the current epoch should continue, so the exception handler does only one thing:
  - 1.1) Record whether a sync is needed: skip_sync = e.skip_sync.
- 2) Restart Horovod and pick up the new number of workers by calling reset() (code in horovod/torch/elastic/__init__.py), which:
  - 2.1) Calls shutdown() to shut down the current Horovod instance.
  - 2.2) Calls init(), which calls _basics.init and re-establishes the communication context, so hvd.size() now reflects the latest number of workers. ElasticSampler picks up the new value as num_replicas the next time ElasticSampler.reset runs (see step 3.3).
- 3) Remove the samples that have already been trained, so that everything fetched afterwards is untrained. If a sync is required, state.sync() is called, which calls SamplerStateHandler.sync, which internally:
  - 3.1) Uses collective communication to gather processed_indices from all workers and merges them into world_processed_indices, the set of indices processed by all workers.
  - 3.2) Calls ElasticSampler.state_dict to get the local ElasticSampler.epoch and ElasticSampler.processed_indices, then assigns world_processed_indices to state_dict['processed_indices'].
  - 3.3) Calls self.value.load_state_dict(broadcast_object(state_dict)):
    - The broadcast ensures that after synchronization all workers hold the same state_dict['processed_indices']. Once it is loaded, the local ElasticSampler.processed_indices contains the indices processed by all workers.
    - load_state_dict calls ElasticSampler.reset again. This reset updates num_replicas and removes processed_indices from the full dataset to obtain a new remaining_indices, so the index-extraction policy of subsequent __iter__ calls changes accordingly.
- 4) Because the trained samples have been removed, remaining_indices contains only untrained samples. When training resumes, samples already trained in this epoch are not trained again, and training continues from the current live parameters.
  - Retraining calls return func(state, *args, **kwargs), which goes through ElasticSampler.__iter__.
  - When __iter__ is called to start iterating, it sets self.indices = self.remaining_indices[:], so indices are drawn only from untrained data.
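Steps 3.1 to 3.3 boil down to set arithmetic. The sketch below simulates them in a single process. The per-worker sets stand in for what allgather_object would return. reset_sampler assumes that ElasticSampler.reset sizes the shards the same way PyTorch's DistributedSampler does, with num_samples = ceil(len(remaining) / num_replicas). Both function names are hypothetical.

```python
import math


def sync_processed(per_worker_processed):
    """Step 3.1: union of every worker's processed_indices (simulated allgather)."""
    return set().union(*per_worker_processed)


def reset_sampler(dataset_len, processed_indices, num_replicas):
    """Step 3.3: recompute the sampler bookkeeping for the new world size."""
    remaining_indices = [i for i in range(dataset_len) if i not in processed_indices]
    num_samples = int(math.ceil(len(remaining_indices) / num_replicas))
    total_size = num_samples * num_replicas
    return remaining_indices, num_samples, total_size


# Two workers trained part of a 10-sample epoch, then a third worker joins
world_processed = sync_processed([{0, 1, 2}, {5, 6, 7}])
remaining, num_samples, total_size = reset_sampler(10, world_processed, num_replicas=3)
print(remaining, num_samples, total_size)  # [3, 4, 8, 9] 2 6
```

Only the four untrained samples are redistributed, and the new worker shares them with the original two.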
The specific logic is as follows:
```
Worker / run_fn
  while True:
    |
    +-- 3) state.sync()                      (only if not skip_sync)
    |     SamplerStateHandler.sync
    |       3.1) allgather processed_indices -> world_processed_indices
    |       3.2) state_dict['processed_indices'] = world_processed_indices
    |       3.3) ElasticSampler.load_state_dict(broadcast_object(state_dict))
    |              -> ElasticSampler.reset: num_replicas, remaining_indices,
    |                 num_samples, total_size
    |
    +-- train: optimizer.apply_gradients -> state.commit()
    |     raises HostsUpdatedInterrupt
    |
    +-- 1) skip_sync = e.skip_sync           (no state.restore())
    |
    +-- 2) reset() -> state.on_reset()       (back to while True)
```
4.7 HorovodInternalError
If a collective operation such as ring allreduce fails, HorovodInternalError(e) is raised. HorovodInternalError is a critical exception. The current state of the epoch is no longer meaningful and must be restored from the latest commit.
In the following code, we keep only the code relevant to HorovodInternalError.
```python
import functools

# notification_manager and HorovodInternalError are defined inside Horovod


def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()  # 3) sync state (including the sampler) across workers

                try:
                    return func(state, *args, **kwargs)  # 4) (re)enter training
                except HorovodInternalError:
                    state.restore()  # 1.1) roll back to the last commit (differs from HostsUpdatedInterrupt)
                    skip_sync = False  # 1.2) record that a sync is needed

                reset()  # 2) calls _basics.init to reinitialize Horovod
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper
```
HorovodInternalError follows almost the same code path as HostsUpdatedInterrupt, with one extra step: state.restore().
Why are node changes also relevant here? Horovod checks for node changes only periodically. A HorovodInternalError may therefore be caused by a node change that has not been detected yet, and going through the same reset path handles that case as well.
The specific logic is:
- 1) HorovodInternalError means that the current epoch must resume from its last commit, so exception handling runs first:
  - 1.1) state.restore() calls SamplerStateHandler.restore (this is where the handling differs from HostsUpdatedInterrupt).
    - SamplerStateHandler.restore calls ElasticSampler.load_state_dict, which restores ElasticSampler from the data originally saved in SamplerStateHandler.__init__ or SamplerStateHandler.save, namely processed_indices and epoch.
    - ElasticSampler.load_state_dict then calls ElasticSampler.reset, which uses processed_indices to remove the trained data. The resulting remaining_indices are untrained with respect to the processed_indices saved at the last commit.
  - 1.2) Record that a sync is needed: skip_sync = False.
- 2) Restart Horovod by calling reset() (code in horovod/torch/elastic/__init__.py), which:
  - 2.1) Calls shutdown() to shut down the current Horovod instance.
  - 2.2) Calls init(), which calls _basics.init and re-establishes the communication context.
- 3) Remove the samples that have already been trained, so that everything fetched afterwards is untrained. Because a sync is required here, state.sync() is called, which calls SamplerStateHandler.sync, which internally:
  - 3.1) Uses collective communication to gather processed_indices from all workers and merges them into world_processed_indices, the set of indices processed by all workers. Every worker's processed_indices was already rolled back to its last commit in step 1.1, because restore() reloaded the data saved in __init__ or save.
  - 3.2) Calls ElasticSampler.state_dict to get the local ElasticSampler.epoch and ElasticSampler.processed_indices, then assigns world_processed_indices to state_dict['processed_indices'].
  - 3.3) Calls self.value.load_state_dict(broadcast_object(state_dict)):
    - The broadcast ensures that after synchronization all workers hold the same state_dict['processed_indices']. Once it is loaded, the local ElasticSampler.processed_indices contains the indices processed by all workers.
    - load_state_dict calls ElasticSampler.reset again. This reset updates num_replicas and removes processed_indices from the full dataset to obtain a new remaining_indices, so the index-extraction policy of subsequent __iter__ calls changes accordingly.
- 4) Training of this epoch then resumes from the state of the last commit.
  - Retraining calls return func(state, *args, **kwargs), which goes through ElasticSampler.__iter__.
  - When __iter__ is called to start iterating, it sets self.indices = self.remaining_indices[:], so indices are drawn only from untrained data.
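The practical difference between the two paths is which processed_indices enter the union in step 3.1. The following minimal sketch uses made-up per-rank sets. `committed` is what SamplerStateHandler.save captured, and `live` is the progress reached when the exception fired. Both helper names are hypothetical.

```python
import copy

# Hypothetical per-rank sampler state
committed = {0: {0, 1}, 1: {4, 5}}     # captured by the last state.commit()
live = {0: {0, 1, 2}, 1: {4, 5, 6}}    # progress when the exception was raised


def after_hosts_updated(live, committed):
    # HostsUpdatedInterrupt: no restore, so the live sets are unioned
    return set().union(*live.values())


def after_internal_error(live, committed):
    # HorovodInternalError: restore() reloads the committed snapshot first
    restored = {rank: copy.deepcopy(s) for rank, s in committed.items()}
    return set().union(*restored.values())


print(sorted(after_hosts_updated(live, committed)))   # [0, 1, 2, 4, 5, 6]
print(sorted(after_internal_error(live, committed)))  # [0, 1, 4, 5]
```

Samples 2 and 6 were seen after the last commit. On the HorovodInternalError path they return to remaining_indices and are trained again, which is consistent with the model parameters also being rolled back.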
The specific logic is shown below:
```
Worker / run_fn
  while True:
    |
    +-- 3) state.sync()                      (skip_sync is False here)
    |     SamplerStateHandler.sync
    |       3.1) allgather processed_indices -> world_processed_indices
    |       3.2) state_dict['processed_indices'] = world_processed_indices
    |       3.3) ElasticSampler.load_state_dict(broadcast_object(state_dict))
    |              -> ElasticSampler.reset: num_replicas, remaining_indices,
    |                 num_samples, total_size
    |
    +-- train: optimizer.apply_gradients -> state.commit()
    |     raises HorovodInternalError
    |
    +-- 1) state.restore()
    |     SamplerStateHandler.restore
    |       -> ElasticSampler.load_state_dict(saved state)
    |            -> ElasticSampler.reset
    |     skip_sync = False
    |
    +-- 2) reset() -> state.on_reset()       (back to while True)
```
4.8 ElasticSampler.__iter__
One question we have not examined yet is when ElasticSampler.__iter__ is called.
```python
import functools


def run_fn(func, reset):
    @functools.wraps(func)
    def wrapper(state, *args, **kwargs):
        notification_manager.init()
        notification_manager.register_listener(state)
        skip_sync = False

        try:
            while True:
                if not skip_sync:
                    state.sync()

                try:
                    return func(state, *args, **kwargs)
                except HorovodInternalError:
                    state.restore()
                    skip_sync = False
                except HostsUpdatedInterrupt as e:
                    skip_sync = e.skip_sync

                reset()
                state.on_reset()
        finally:
            notification_manager.remove_listener(state)
    return wrapper
```
The elastic logic wraps full_train through the @hvd.elastic.run decorator, so func is full_train.
```python
@hvd.elastic.run
def full_train(state):
    while state.epoch < args.epochs:
        train(state)
        validate(state.epoch)
        save_checkpoint(state.epoch)
        end_epoch(state)
```
Let’s look at the main code for train:
```python
import math

import torch.nn.functional as F
from tqdm import tqdm

# model, optimizer, train_loader, args, train_loss, train_accuracy, accuracy
# and allreduce_batch_size are defined elsewhere in the training script


def train(state):
    model.train()
    epoch = state.epoch
    with tqdm(...):
        # enumerate(train_loader) calls ElasticSampler.__iter__
        for idx, (data, target) in enumerate(train_loader):
            # Split data into sub-batches of size batch_size
            for i in range(0, len(data), args.batch_size):
                data_batch = data[i:i + args.batch_size]
                target_batch = target[i:i + args.batch_size]
                output = model(data_batch)
                train_accuracy.update(accuracy(output, target_batch))
                loss = F.cross_entropy(output, target_batch)
                train_loss.update(loss)
                # Average gradients among sub-batches
                loss.div_(math.ceil(float(len(data)) / args.batch_size))
                loss.backward()

            # Elastic Horovod: record which samples were processed this batch
            # so we do not reprocess them if a reset event occurs
            state.train_sampler.record_batch(idx, allreduce_batch_size)

            # Gradient is applied across all ranks
            optimizer.step()
            state.commit()
```
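record_batch is what feeds processed_indices. The article does not show its body. The sketch below assumes that record_batch(batch_idx, batch_size) marks the batch_idx-th slice of this worker's shard (self.indices, as produced by __iter__) as processed, which appears to be how recent Horovod versions implement it. ShardRecorder is a hypothetical name.

```python
class ShardRecorder:
    """Hypothetical helper mimicking ElasticSampler.record_batch bookkeeping."""

    def __init__(self, shard):
        self.indices = list(shard)        # this worker's indices from __iter__
        self.processed_indices = set()

    def get_indices(self, batch_idx, batch_size):
        # Slice of the shard covered by batch number batch_idx
        start = batch_idx * batch_size
        end = min(start + batch_size, len(self.indices))
        return self.indices[start:end]

    def record_batch(self, batch_idx, batch_size):
        self.processed_indices.update(self.get_indices(batch_idx, batch_size))


recorder = ShardRecorder([7, 3, 9, 1, 5])
for idx in range(2):                       # two batches finished before a reset
    recorder.record_batch(idx, batch_size=2)
print(sorted(recorder.processed_indices))  # [1, 3, 7, 9]
```

Under this assumption, the batch size passed to record_batch must match the batch size the loader actually yields (allreduce_batch_size in the excerpt above). Otherwise the recorded slices would not line up with the samples that were trained.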
So we can work out the general logic:
- After an error, train is called again, and enumerate(train_loader) calls ElasticSampler.__iter__.
- num_replicas has already been set during reset, so __iter__ now rebuilds the extraction policy from the new world size and remaining_indices.
```python
def __iter__(self):
    self.indices = self.remaining_indices[:]

    # Shuffle indices across workers deterministically in place
    seed = self.seed + self.epoch
    random.Random(seed).shuffle(self.indices)

    # Add extra samples to make it evenly divisible
    self.indices += self.indices[:(self.total_size - len(self.indices))]
    assert len(self.indices) == self.total_size

    # Subsample: how does this worker traverse the indices?
    # Start at self.rank, stop at total_size, step by num_replicas
    self.indices = self.indices[self.rank:self.total_size:self.num_replicas]
    assert len(self.indices) == self.num_samples

    # Return an iterator over this worker's indices
    return iter(self.indices)
```
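To see the extraction policy in isolation, the sampler's attributes can be turned into function parameters. shard_indices is a hypothetical helper that restates the __iter__ logic above, and it computes num_samples and total_size with the same ceil-based sizing assumed for ElasticSampler.reset. It shows that the shards for a new world size cover exactly the untrained indices.

```python
import math
import random


def shard_indices(remaining_indices, seed, epoch, rank, num_replicas):
    """Hypothetical standalone restatement of ElasticSampler.__iter__."""
    num_samples = int(math.ceil(len(remaining_indices) / num_replicas))
    total_size = num_samples * num_replicas

    indices = remaining_indices[:]
    # Same seed on every worker, so every worker sees the same permutation
    random.Random(seed + epoch).shuffle(indices)

    # Pad by repeating the head of the list so it divides evenly
    indices += indices[:(total_size - len(indices))]
    assert len(indices) == total_size

    # Worker `rank` takes every num_replicas-th index starting at rank
    shard = indices[rank:total_size:num_replicas]
    assert len(shard) == num_samples
    return shard


remaining = [3, 4, 8, 9]  # untrained indices after a resize to 3 workers
shards = [shard_indices(remaining, seed=0, epoch=1, rank=r, num_replicas=3)
          for r in range(3)]
print([len(s) for s in shards])                # [2, 2, 2]
print(sorted(set().union(*map(set, shards))))  # [3, 4, 8, 9]
```

Every shard is a stride over one shared permutation. As a result, shards overlap only through the padding duplicates, and no index that was already processed can reappear.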
The specific logic is as follows:
1) num_replicas is set during reset.
2) ElasticSampler.__iter__ re-determines the data extraction strategy based on the new world size and remaining_indices.
```
Worker / run_fn
  while True:
    |
    +-- state.sync()
    |
    +-- train
    |     enumerate(train_loader) --2)--> ElasticSampler.__iter__
    |                                      (remaining_indices, num_replicas)
    |     optimizer.apply_gradients -> state.commit()
    |     raises HostsUpdatedInterrupt or HorovodInternalError
    |
    +-- state.restore()                      (HorovodInternalError only)
    |
    +-- 1) reset() -> state.on_reset()       (back to while True)
```
This completes the analysis of how elastic training recovers. The analysis may continue with the PyTorch DistributedOptimizer.
0xEE Personal information
★★★★ Thoughts on life and technology ★★★★★
WeChat official account: Rosie's Thoughts