Fix DDP checkpoint (#1415)
* fix

* finalize fix

* rollback

* fix

* simplify broadcasting

* fix

* fix
Louis-Dupont authored Aug 30, 2023
1 parent 32fc041 commit 3e1019f
Showing 14 changed files with 140 additions and 34 deletions.
@@ -5,6 +5,7 @@
from datetime import datetime

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import execute_and_distribute_from_master


try:
@@ -16,6 +17,7 @@
logger = get_logger(__name__)


@execute_and_distribute_from_master
def generate_run_id() -> str:
"""Generate a unique run ID based on the current timestamp.
@@ -35,11 +37,12 @@ def is_run_dir(dirname: str) -> bool:

def get_latest_run_id(experiment_name: str, checkpoints_root_dir: Optional[str] = None) -> Optional[str]:
"""
:param experiment_name: Name of the experiment.
:param checkpoints_root_dir: Path to the directory where all the experiments are organised, each sub-folder representing a specific experiment.
:param experiment_name: Name of the experiment.
:param checkpoints_root_dir: Path to the directory where all the experiments are organised, each sub-folder representing a specific experiment.
If None, SG will first check if a package named 'checkpoints' exists.
If not, SG will look for the root of the project that includes the script that was launched.
If not found, raise an error.
:return: Latest valid run ID, in the format "RUN_<year>"
"""
experiment_dir = get_experiment_dir_path(checkpoints_root_dir=checkpoints_root_dir, experiment_name=experiment_name)

@@ -51,7 +54,7 @@ def get_latest_run_id(experiment_name: str, checkpoints_root_dir: Optional[str]
f"Trying to load the n-1 most recent run..."
)
else:
return run_dir
return os.path.basename(run_dir)


def validate_run_id(run_id: str, experiment_name: str, ckpt_root_dir: Optional[str] = None):
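For reference, the get_latest_run_id change above now returns just the run folder's name rather than its full path. A quick illustration with a made-up path (the directory layout shown here is assumed, not taken from the repo):

>>> import os
>>> os.path.basename("/path/to/checkpoints/my_experiment/RUN_20230830_120000")
'RUN_20230830_120000'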
94 changes: 93 additions & 1 deletion src/super_gradients/common/environment/ddp_utils.py
@@ -1,6 +1,10 @@
import os
import socket
from functools import wraps
import os
from typing import Any, List, Callable

import torch
import torch.distributed as dist

from super_gradients.common.environment.device_utils import device_config
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
@@ -77,3 +81,91 @@ def find_free_port() -> int:
sock.bind(("", 0))
_ip, port = sock.getsockname()
return port


def get_local_rank():
"""
Returns the local rank if running in DDP, and 0 otherwise
:return: local rank
"""
return dist.get_rank() if dist.is_initialized() else 0


def require_ddp_setup() -> bool:
from super_gradients.common import MultiGPUMode

return device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and device_config.assigned_rank != get_local_rank()


def is_ddp_subprocess():
return torch.distributed.get_rank() > 0 if dist.is_initialized() else False


def get_world_size() -> int:
"""
Returns the world size if running in DDP, and 1 otherwise
:return: world size
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()


def get_device_ids() -> List[int]:
return list(range(get_world_size()))


def count_used_devices() -> int:
return len(get_device_ids())


def execute_and_distribute_from_master(func: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorator to execute a function on the master process and distribute the result to all other processes.
Useful in parallel computing scenarios where a computational task needs to be performed only on the master
node (e.g., a computational-heavy calculation), and the result must be shared with other nodes without
redundant computation.
Example usage:
>>> @execute_and_distribute_from_master
>>> def some_code_to_run(param1, param2):
>>> return param1 + param2
The wrapped function will only be executed on the master node, and the result will be propagated to all
other nodes.
:param func: The function to be executed on the master process and whose result is to be distributed.
:return: A wrapper function that encapsulates the execute-and-distribute logic.
"""

@wraps(func)
def wrapper(*args, **kwargs):
# Run the function only if it's the master process
if device_config.assigned_rank <= 0:
result = func(*args, **kwargs)
else:
result = None

# Broadcast the result from the master process to all nodes
return broadcast_from_master(result)

return wrapper


def broadcast_from_master(data: Any) -> Any:
"""
Broadcast data from master node to all other nodes. This may be required when you
want to compute something only on master node (e.g computational-heavy metric) and
don't want to waste CPU of other nodes doing the same work simultaneously.
:param data: Data to be broadcasted from master node (rank 0)
:return: Data from rank 0 node
"""
world_size = get_world_size()
if world_size == 1:
return data
broadcast_list = [data] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(broadcast_list, src=0)
return broadcast_list[0]
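
To make the intent of the two helpers above concrete, here is a minimal usage sketch of the same pattern this commit applies to generate_run_id: a value is computed once on the master process and shared with every DDP rank, so all processes agree on (for example) the run directory they checkpoint into. It assumes a SuperGradients build that includes this commit; make_run_id and its timestamp format are illustrative, not the library's actual implementation.

from datetime import datetime

from super_gradients.common.environment.ddp_utils import execute_and_distribute_from_master


@execute_and_distribute_from_master
def make_run_id() -> str:
    # Executed on rank 0 only; the returned string is broadcast to all other ranks.
    return "RUN_" + datetime.now().strftime("%Y%m%d_%H%M%S")


if __name__ == "__main__":
    # In a single-process run, the world size is 1 and the value is returned directly.
    # Under torchrun with several ranks, every rank receives the identical string computed on rank 0.
    print(make_run_id())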
3 changes: 2 additions & 1 deletion src/super_gradients/common/environment/omegaconf_utils.py
@@ -4,7 +4,6 @@

from omegaconf import OmegaConf, DictConfig

from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path
from hydra.experimental.callback import Callback


@@ -72,6 +71,8 @@ def get_cls(cls_path: str):


def hydra_output_dir_resolver(ckpt_root_dir: str, experiment_name: str) -> str:
from super_gradients.common.environment.checkpoints_dir_utils import get_checkpoints_dir_path

return get_checkpoints_dir_path(experiment_name=experiment_name, ckpt_root_dir=ckpt_root_dir)


3 changes: 2 additions & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -133,7 +133,8 @@ def _setup_dir(self):
# Only if it exists, i.e. if hydra was used.
if os.path.exists(source_hydra_path):
destination_hydra_path = os.path.join(self._local_dir, ".hydra")
shutil.copytree(source_hydra_path, destination_hydra_path, dirs_exist_ok=True)
if not os.path.exists(destination_hydra_path):
shutil.copytree(source_hydra_path, destination_hydra_path)

@multi_process_safe
def _init_log_file(self):
2 changes: 1 addition & 1 deletion src/super_gradients/training/dataloaders/dataloaders.py
@@ -37,8 +37,8 @@
from super_gradients.training.utils import get_param
from super_gradients.training.utils.distributed_training_utils import (
wait_for_the_master,
get_local_rank,
)
from super_gradients.common.environment.ddp_utils import get_local_rank
from super_gradients.training.utils.utils import override_default_params_without_nones
from super_gradients.common.environment.cfg_utils import load_dataset_params
import torch.distributed as dist
2 changes: 1 addition & 1 deletion src/super_gradients/training/datasets/datasets_utils.py
@@ -23,7 +23,7 @@
from super_gradients.common.registry.registry import register_collate_function, register_callback, register_transform
from super_gradients.training.datasets.auto_augment import rand_augment_transform
from super_gradients.training.utils.detection_utils import DetectionVisualization, Anchors
from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
from super_gradients.common.environment.ddp_utils import get_local_rank, get_world_size
from super_gradients.training.utils.utils import AverageMeter


4 changes: 1 addition & 3 deletions src/super_gradients/training/losses/ppyolo_loss.py
@@ -10,9 +10,7 @@
from super_gradients.common.registry.registry import register_loss
from super_gradients.training.datasets.data_formats.bbox_formats.cxcywh import cxcywh_to_xyxy
from super_gradients.training.utils.bbox_utils import batch_distance2bbox
from super_gradients.training.utils.distributed_training_utils import (
get_world_size,
)
from super_gradients.common.environment.ddp_utils import get_world_size


def batch_iou_similarity(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-9) -> float:
3 changes: 2 additions & 1 deletion src/super_gradients/training/metrics/detection_metrics.py
@@ -6,6 +6,7 @@
from torchmetrics import Metric

import super_gradients
import super_gradients.common.environment.ddp_utils
from super_gradients.common.object_names import Metrics
from super_gradients.common.registry.registry import register_metric
from super_gradients.training.utils import tensor_container_to_device
@@ -222,7 +223,7 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None):
:return:
"""
if self.world_size is None:
self.world_size = torch.distributed.get_world_size() if self.is_distributed else -1
self.world_size = super_gradients.common.environment.ddp_utils.get_world_size() if self.is_distributed else -1
if self.rank is None:
self.rank = torch.distributed.get_rank() if self.is_distributed else -1

@@ -6,6 +6,7 @@
from torch import Tensor
from torchmetrics import Metric

import super_gradients.common.environment.ddp_utils
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.ddp_utils import is_distributed
from super_gradients.common.object_names import Metrics
@@ -258,7 +259,7 @@ def _sync_dist(self, dist_sync_fn=None, process_group=None):
:return:
"""
if self.world_size is None:
self.world_size = torch.distributed.get_world_size() if self.is_distributed else -1
self.world_size = super_gradients.common.environment.ddp_utils.get_world_size() if self.is_distributed else -1
if self.rank is None:
self.rank = torch.distributed.get_rank() if self.is_distributed else -1

@@ -13,7 +13,8 @@

__all__ = ["CSPResNetBackbone", "CSPResNetBasicBlock"]

from super_gradients.training.utils.distributed_training_utils import wait_for_the_master, get_local_rank
from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
from super_gradients.common.environment.ddp_utils import get_local_rank


class CSPResNetBasicBlock(nn.Module):
6 changes: 1 addition & 5 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -61,14 +61,10 @@
compute_precise_bn_stats,
setup_device,
get_gpu_mem_utilization,
get_world_size,
get_local_rank,
require_ddp_setup,
get_device_ids,
is_ddp_subprocess,
wait_for_the_master,
DDPNotSetupException,
)
from super_gradients.common.environment.ddp_utils import get_local_rank, require_ddp_setup, is_ddp_subprocess, get_world_size, get_device_ids
from super_gradients.training.utils.ema import ModelEMA
from super_gradients.training.utils.optimizer_utils import build_optimizer
from super_gradients.training.utils.sg_trainer_utils import MonitoredValue, log_main_training_params
3 changes: 2 additions & 1 deletion src/super_gradients/training/utils/checkpoint_utils.py
@@ -13,7 +13,8 @@
from super_gradients.common.decorators.explicit_params_validator import explicit_params_validation
from super_gradients.module_interfaces import HasPredict
from super_gradients.training.pretrained_models import MODEL_URLS
from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master
from super_gradients.training.utils.distributed_training_utils import wait_for_the_master
from super_gradients.common.environment.ddp_utils import get_local_rank
from super_gradients.training.utils.utils import unwrap_model

try:
37 changes: 24 additions & 13 deletions src/super_gradients/training/utils/distributed_training_utils.py
@@ -14,6 +14,7 @@
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from super_gradients.common.deprecate import deprecated
from super_gradients.common.environment.ddp_utils import init_trainer
from super_gradients.common.data_types.enum import MultiGPUMode
from super_gradients.common.environment.argparse_utils import EXTRA_ARGS
@@ -27,6 +28,14 @@
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.factories.type_factory import TypeFactory

from super_gradients.common.environment.ddp_utils import get_local_rank as _get_local_rank
from super_gradients.common.environment.ddp_utils import is_ddp_subprocess as _is_ddp_subprocess
from super_gradients.common.environment.ddp_utils import get_world_size as _get_world_size
from super_gradients.common.environment.ddp_utils import get_device_ids as _get_device_ids
from super_gradients.common.environment.ddp_utils import count_used_devices as _count_used_devices
from super_gradients.common.environment.ddp_utils import require_ddp_setup as _require_ddp_setup


logger = get_logger(__name__)


@@ -145,40 +154,42 @@ def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoad
bn.momentum = momentums[i]


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_get_local_rank)
def get_local_rank():
"""
Returns the local rank if running in DDP, and 0 otherwise
:return: local rank
"""
return dist.get_rank() if dist.is_initialized() else 0


def require_ddp_setup() -> bool:
return device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL and device_config.assigned_rank != get_local_rank()
return _get_local_rank()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_is_ddp_subprocess)
def is_ddp_subprocess():
return torch.distributed.get_rank() > 0 if dist.is_initialized() else False
return _is_ddp_subprocess()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_get_world_size)
def get_world_size() -> int:
"""
Returns the world size if running in DDP, and 1 otherwise
:return: world size
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
return _get_world_size()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_get_device_ids)
def get_device_ids() -> List[int]:
return list(range(get_world_size()))
return _get_device_ids()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_count_used_devices)
def count_used_devices() -> int:
return len(get_device_ids())
return _count_used_devices()


@deprecated(deprecated_since="3.2.1", removed_from="3.5.0", target=_require_ddp_setup)
def require_ddp_setup() -> bool:
return _require_ddp_setup()


@contextmanager
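The hunks above wrap the old distributed_training_utils helpers in @deprecated shims that forward to their new ddp_utils implementations. The decorator's internals are not part of this diff; purely as an illustration of that shim pattern, a minimal stand-in with the same three parameters could look like the sketch below (deprecated_alias is a hypothetical name, not super_gradients.common.deprecate.deprecated itself).

import warnings
from functools import wraps
from typing import Any, Callable


def deprecated_alias(deprecated_since: str, removed_from: str, target: Callable[..., Any]) -> Callable:
    """Hypothetical stand-in for a deprecation decorator: warn callers of the old
    symbol and delegate the call to its new location."""

    def decorator(old_func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(old_func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(
                f"{old_func.__name__} is deprecated since {deprecated_since} and will be removed in "
                f"{removed_from}; use {target.__module__}.{target.__name__} instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            return target(*args, **kwargs)

        return wrapper

    return decorator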
Expand Up @@ -12,7 +12,7 @@
from tqdm import tqdm

from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training.utils.distributed_training_utils import get_local_rank, get_world_size
from super_gradients.common.environment.ddp_utils import get_local_rank, get_world_size
from torch.distributed import all_gather

from super_gradients.training.utils.utils import infer_model_device
