Skip to content

Commit

Permalink
Load ub_cfg from hydra config (#7003)
Browse files Browse the repository at this point in the history
* Pass tp config via hydra

Signed-off-by: Jan Baczek <[email protected]>

* Remove self.ub_cfgs field - it isn't used anywhere else

Signed-off-by: Jan Baczek <[email protected]>

* Allow tp_overlap tree substitution in hydra config

Signed-off-by: Jan Baczek <[email protected]>

* Add warning in case of usage of the default tp config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change warning message

Signed-off-by: Jan Baczek <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add compute capability resolver

Signed-off-by: Jan Baczek <[email protected]>

* Bugfix

Signed-off-by: Jan Baczek <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add guards to pynvml import

Signed-off-by: Jan Baczek <[email protected]>

---------

Signed-off-by: Jan Baczek <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: yaoyu-33 <[email protected]>
  • Loading branch information
3 people authored Aug 12, 2023
1 parent b95a169 commit 2b9c5f4
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 10 deletions.
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- optional [email protected]_tp_comm_overlap_cfg:

name: megatron_gpt
restore_from_path: null # used when starting from a .nemo file

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# UB communicator configurations
# Model configs: A100/175B/TP4/MBS1/SeqLen2K/BF16

# Bulk overlap with AllGather
qkv_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 0

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 4
num_splits: 4
set_sm_margin: 0

fc2_fprop:
method: pipeline
num_sm: 4
num_splits: 4
set_sm_margin: 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# UB communicator configurations
# Model configs: A100/175B/TP4/MBS2/SeqLen2K/BF16

# Bulk overlap with AllGather
qkv_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 0

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 8
num_splits: 4
set_sm_margin: 0

fc2_fprop:
method: pipeline
num_sm: 4
num_splits: 4
set_sm_margin: 0
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# UB communicator configurations
# Model configs: H100/175B/TP4/MBS1/SeqLen2K/FP8

# Bulk overlap with AllGather / ReduceScatter
qkv_dgrad:
method: bulk
num_sm: 4
cga_size: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 8
cga_size: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 2
cga_size: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 4
cga_size: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 0

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 1

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 24
cga_size: 2
num_splits: 4
set_sm_margin: 1

fc2_fprop:
method: pipeline
num_sm: 20
cga_size: 2
num_splits: 4
set_sm_margin: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# UB communicator configurations
# Model configs: H100/175B/TP8/MBS2/SeqLen2K/FP8

# Bulk overlap with AllGather
qkv_dgrad:
method: bulk
num_sm: 8
cga_size: 2
set_sm_margin: 0

qkv_wgrad:
method: bulk
num_sm: 16
cga_size: 2
set_sm_margin: 0

fc1_dgrad:
method: bulk
num_sm: 4
cga_size: 2
set_sm_margin: 0

fc1_wgrad:
method: bulk
num_sm: 16
cga_size: 2
set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
method: ring_exchange
aggregate: 0

proj_dgrad:
method: ring_exchange
aggregate: 1

fc1_fprop:
method: ring_exchange
aggregate: 0

fc2_dgrad:
method: ring_exchange
aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
method: pipeline
num_sm: 16
cga_size: 2
num_splits: 4
set_sm_margin: 1

fc2_fprop:
method: pipeline
num_sm: 24
cga_size: 2
num_splits: 4
set_sm_margin: 1
Original file line number Diff line number Diff line change
Expand Up @@ -522,20 +522,17 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
return loss_mean

def initialize_ub_func(self):
ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None)
if ub_cfgs is None:
warnings.warn(
"Couldn't find TP config. Please check the path correctness. Initializing TP comm overlap with the default config."
)

input_shape = [
self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
self.cfg.get('hidden_size'),
]
ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
ub_cfgs = None
if ub_cfg_file_name is not None:
try:
import yaml

with open(ub_cfg_file_name, 'r') as ub_cfg_file:
ub_cfgs = yaml.safe_load(ub_cfg_file)
except (ImportError, TypeError):
logging.error(f"Fail to read ub_tp_comm_overlap config file: {ub_cfg_file_name}.")

te_module.base.initialize_ub(
shape=input_shape,
tp_size=self.cfg.get('tensor_model_parallel_size'),
Expand Down
21 changes: 21 additions & 0 deletions nemo/core/config/hydra_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,27 @@
from hydra.types import TaskFunction
from omegaconf import DictConfig, OmegaConf


def _get_gpu_name():
try:
import pynvml
except (ImportError, ModuleNotFoundError):
return None

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
cuda_capability, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
pynvml.nvmlShutdown()
if cuda_capability == 8:
return "a100"
elif cuda_capability == 9:
return "h100"
else:
return None


OmegaConf.register_new_resolver("gpu_name", _get_gpu_name)

# multiple interpolated values in the config
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y)

Expand Down

0 comments on commit 2b9c5f4

Please sign in to comment.