diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cba7adb2..58c41f78 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -118,7 +118,7 @@ jobs: src/test/ - name: Test checkpoint (GPU) - image: olmo-core + image: olmo-core-nightly gpus: 2 run: | pytest -v --color=yes --durations=3 -m gpu \ @@ -180,10 +180,11 @@ jobs: gpuCount: ${{ matrix.task.gpus }} constraints: cluster: - - ai2/allennlp-cirrascale - - ai2/allennlp-elanding-a100-40g + # - ai2/allennlp-cirrascale + # - ai2/allennlp-elanding-a100-40g - ai2/pluto-cirrascale - ai2/jupiter-cirrascale-2 + # - ai2/saturn-cirrascale envVars: - name: CUBLAS_WORKSPACE_CONFIG value: ":16:8" diff --git a/CHANGELOG.md b/CHANGELOG.md index 542ccd53..fdfd5b51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added `DownstreamEvaluatorCallbackConfig` class for running in-loop downstream eval via [OLMo-in-loop-evals](https://github.com/allenai/OLMo-in-loop-evals). + +### Removed + +- Removed `flash-attn` from the Beaker images since `flash-attn` currently can't be built for torch 2.5.1. We are waiting on updates from the `flash-attn` maintainers. See https://github.com/Dao-AILab/flash-attention/issues/1302. + ## [v1.5.0](https://github.com/allenai/OLMo-core/releases/tag/v1.5.0) - 2024-10-23 ### Added diff --git a/Makefile b/Makefile index d5cda160..cba52330 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,13 @@ -BASE_IMAGE = ghcr.io/allenai/pytorch:2.4.1-cuda12.1-python3.11 +# NOTE: make sure CUDA versions match across these variables +BASE_IMAGE = ghcr.io/allenai/pytorch:2.5.1-cuda12.1-python3.11-v2024.10.29 +CUDA_TOOLKIT_VERSION = 12.1.0 +TORCH_CUDA_VERSION = 121 # NOTE: when upgrading the nightly version you also need to upgrade the torch version specification # in 'pyproject.toml' to include that nightly version. -NIGHTLY_VERSION = "2.6.0.dev20241009+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121" -TORCHAO_VERSION = "torchao==0.5.0 --extra-index-url https://download.pytorch.org/whl/cu121" +NIGHTLY_VERSION = "2.6.0.dev20241009+cu121" +TORCHAO_VERSION = "torchao==0.5.0" MEGABLOCKS_VERSION = "megablocks[gg] @ git+https://git@github.com/epwalsh/megablocks.git@epwalsh/deps" -CUDA_TOOLKIT_VERSION = 12.1.0 VERSION = $(shell python src/olmo_core/version.py) VERSION_SHORT = $(shell python src/olmo_core/version.py short) @@ -49,6 +51,7 @@ stable-image : --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg BASE=$(BASE_IMAGE) \ --build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \ + --build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \ --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --target stable \ @@ -62,6 +65,7 @@ nightly-image : --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg BASE=$(BASE_IMAGE) \ --build-arg CUDA_TOOLKIT_VERSION=$(CUDA_TOOLKIT_VERSION) \ + --build-arg TORCH_CUDA_VERSION=$(TORCH_CUDA_VERSION) \ --build-arg MEGABLOCKS_VERSION=$(MEGABLOCKS_VERSION) \ --build-arg TORCHAO_VERSION=$(TORCHAO_VERSION) \ --build-arg NIGHTLY_VERSION=$(NIGHTLY_VERSION) \ diff --git a/pyproject.toml b/pyproject.toml index 5af9c589..89f40d2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "omegaconf", "safetensors", "importlib_resources", + "ai2-olmo-eval==0.2.0", ] [project.urls] diff --git a/src/Dockerfile b/src/Dockerfile index 1a7255a5..a5726d0b 100644 --- a/src/Dockerfile +++ b/src/Dockerfile @@ -13,12 +13,25 @@ WORKDIR /app/build ARG CUDA_TOOLKIT_VERSION RUN conda install -y -c nvidia cuda-toolkit==${CUDA_TOOLKIT_VERSION} +ARG TORCH_CUDA_VERSION + # Build megablocks and grouped-gemm. ENV TORCH_CUDA_ARCH_LIST="8.0 9.0" ENV GROUPED_GEMM_CUTLASS=1 ARG MEGABLOCKS_VERSION -RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" \ - && rm -rf torch-*.whl numpy-*.whl triton-*.whl +RUN pip wheel --no-build-isolation --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \ + "${MEGABLOCKS_VERSION}" + +# Flash-attn from pre-built wheel (can't get this to work at the moment) +#RUN wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl + +# Only keep the target wheels and dependencies with CUDA extensions. +RUN echo "Built wheels:" \ + && ls -lh . \ + && ls -1 | grep -Ev 'megablocks|grouped_gemm|stanford_stk|flash_attn' | xargs rm \ + && echo "Final wheels:" \ + && ls -lh . ######################################################################### # Stable image @@ -26,9 +39,13 @@ RUN pip wheel --no-build-isolation --no-cache-dir "${MEGABLOCKS_VERSION}" \ FROM ${BASE} as stable +ARG TORCH_CUDA_VERSION + # Install torchao. ARG TORCHAO_VERSION -RUN pip install --no-cache-dir ${TORCHAO_VERSION} +RUN pip install --no-cache-dir \ + --extra-index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} \ + ${TORCHAO_VERSION} # Copy and install wheels from build image. COPY --from=build /app/build /app/build @@ -50,5 +67,9 @@ WORKDIR /app/olmo-core FROM stable as nightly +ARG TORCH_CUDA_VERSION + ARG NIGHTLY_VERSION -RUN pip install --no-cache-dir --pre torch==${NIGHTLY_VERSION} +RUN pip install --no-cache-dir --pre \ + --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} \ + torch==${NIGHTLY_VERSION} diff --git a/src/examples/train.py b/src/examples/train.py index deec8307..0cc3d525 100644 --- a/src/examples/train.py +++ b/src/examples/train.py @@ -31,6 +31,7 @@ CheckpointerCallback, CometCallback, ConfigSaverCallback, + DownstreamEvaluatorCallbackConfig, GPUMemoryMonitorCallback, GradClipperCallback, LMEvaluatorCallbackConfig, @@ -133,7 +134,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: .with_callback("config_saver", ConfigSaverCallback()) .with_callback("profiler", ProfilerCallback(enabled=False)) .with_callback( - "evaluator", + "lm_evaluator", LMEvaluatorCallbackConfig( eval_dataset=NumpyDatasetConfig( paths=["/net/nfs/allennlp/llm-data/c4/en/c4-validation.00000-00008.npy"], @@ -147,6 +148,14 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: eval_duration=Duration.steps(10), ), ) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=["hellaswag"], + tokenizer=tokenizer_config, + eval_interval=250, + ), + ) ) return ExperimentConfig( diff --git a/src/olmo_core/data/mixes/__init__.py b/src/olmo_core/data/mixes/__init__.py index b2e79d8b..b8adecc8 100644 --- a/src/olmo_core/data/mixes/__init__.py +++ b/src/olmo_core/data/mixes/__init__.py @@ -1,4 +1,5 @@ import os +from abc import abstractmethod from contextlib import contextmanager from pathlib import Path from typing import Generator, List, Tuple @@ -15,7 +16,8 @@ class DataMixBase(StrEnum): Base class for enumeration of data mixes. """ - def build(self, base_dir: str, tokenizer: TokenizerName) -> Tuple[List[str], List[str]]: + @abstractmethod + def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]: """ Construct the data mix. @@ -37,7 +39,7 @@ class DataMix(DataMixBase): dolma17 = "dolma17" v3_small_ppl_validation = "v3-small-ppl-validation" - def build(self, base_dir: str, tokenizer: TokenizerName) -> Tuple[List[str], List[str]]: + def build(self, base_dir: str, tokenizer: str) -> Tuple[List[str], List[str]]: if not base_dir.endswith("/"): base_dir = base_dir + "/" diff --git a/src/olmo_core/data/tokenizer.py b/src/olmo_core/data/tokenizer.py index 4a66e504..4cf950a3 100644 --- a/src/olmo_core/data/tokenizer.py +++ b/src/olmo_core/data/tokenizer.py @@ -3,15 +3,31 @@ from ..config import Config, StrEnum +__all__ = [ + "TokenizerConfig", + "TokenizerName", +] + class TokenizerName(StrEnum): """ - An enumeration of supported tokenizer names. + An enumeration of tokenizer identifiers commonly used OLMo researchers. """ dolma2 = "allenai/dolma2-tokenizer" + """ + The dolma2 tokenizer. + """ + gpt_neox_olmo_dolma_v1_5 = "allenai/gpt-neox-olmo-dolma-v1_5" + """ + A modified GPT NeoX tokenizer. + """ + gpt2 = "gpt2" + """ + The base GPT2 tokenizer. + """ @dataclass @@ -21,10 +37,29 @@ class TokenizerConfig(Config): """ vocab_size: int + """ + The vocab size. + """ + eos_token_id: int + """ + The end-of-sentence token ID. + """ + pad_token_id: int + """ + The padding token ID. + """ + bos_token_id: Optional[int] = None - identifier: Optional[TokenizerName] = None + """ + The begin-of-sentence token ID. + """ + + identifier: Optional[str] = None + """ + The identifier of the tokenizer. Could be a path or HuggingFace identifier. + """ def padded_vocab_size(self, pad_multiple: int = 128) -> int: """ @@ -35,6 +70,9 @@ def padded_vocab_size(self, pad_multiple: int = 128) -> int: @classmethod def dolma2(cls) -> "TokenizerConfig": + """ + Get a :data:`~TokenizerName.dolma2` tokenizer config. + """ return cls( vocab_size=100278, eos_token_id=100257, @@ -44,6 +82,9 @@ def dolma2(cls) -> "TokenizerConfig": @classmethod def gpt_neox_olmo_dolma_v1_5(cls) -> "TokenizerConfig": + """ + Get a :data:`~TokenizerName.gpt_neox_olmo_dolma_v1_5` tokenizer config. + """ return cls( vocab_size=50280, eos_token_id=50279, @@ -53,6 +94,9 @@ def gpt_neox_olmo_dolma_v1_5(cls) -> "TokenizerConfig": @classmethod def gpt2(cls) -> "TokenizerConfig": + """ + Get a :data:`~TokenizerName.gpt2` tokenizer config. + """ return cls( vocab_size=50280, eos_token_id=50256, diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index 6b645832..605e0d3a 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -13,7 +13,7 @@ from ..config import StrEnum from ..exceptions import OLMoConfigurationError, OLMoEnvironmentError -from ..utils import get_default_device, set_env_var +from ..utils import get_default_device, move_to_device, set_env_var OLMO_SHARED_FS_ENV_VAR = "OLMO_SHARED_FS" OLMO_FS_LOCAL_RANK_ENV_VAR = "FS_LOCAL_RANK" @@ -270,6 +270,23 @@ def scatter_object(obj: T, src: int = 0, group: Optional[dist.ProcessGroup] = No return output_list[0] +def all_gather( + tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None +) -> List[torch.Tensor]: + """ + All-gather tensors from the whole group into a list. + """ + if not is_distributed(): + return [tensor] + + shapes = all_gather_object(tensor.shape, group=group) + output_list = [ + move_to_device(torch.zeros(shape, dtype=tensor.dtype), tensor.device) for shape in shapes + ] + dist.all_gather(output_list, tensor, group=group) + return output_list + + def all_gather_object(obj: T, group: Optional[dist.ProcessGroup] = None) -> List[T]: """ All-gather an object using pickle to all ranks in a process group. diff --git a/src/olmo_core/eval/metrics.py b/src/olmo_core/eval/metrics.py index 03a78d9e..e2c87c1d 100644 --- a/src/olmo_core/eval/metrics.py +++ b/src/olmo_core/eval/metrics.py @@ -67,6 +67,10 @@ def __init__( def update( self, value: Union[float, torch.Tensor], weight: Union[float, torch.Tensor] = 1.0 ) -> None: + """ + :param value: The latest value to update the metric with. Could be a tensor of values. + :param weight: The corresponding weight(s) for the value. Should be the same shape as ``value``. + """ value = self.as_tensor(value) weight = torch.broadcast_to(self.as_tensor(weight), value.shape) if value.numel() == 0: @@ -75,6 +79,9 @@ def update( self.weight += weight.sum() def compute(self) -> torch.Tensor: + """ + Computes the mean over the values and weights given. + """ weighted_sum = all_reduce_value( self.weighted_sum, device=self.device, group=self.process_group ) diff --git a/src/olmo_core/train/callbacks/__init__.py b/src/olmo_core/train/callbacks/__init__.py index 44028505..05505104 100644 --- a/src/olmo_core/train/callbacks/__init__.py +++ b/src/olmo_core/train/callbacks/__init__.py @@ -3,7 +3,11 @@ from .comet import CometCallback, CometNotificationSetting from .config_saver import ConfigSaverCallback from .console_logger import ConsoleLoggerCallback -from .evaluator_callback import EvaluatorCallback, LMEvaluatorCallbackConfig +from .evaluator_callback import ( + DownstreamEvaluatorCallbackConfig, + EvaluatorCallback, + LMEvaluatorCallbackConfig, +) from .float8_handler import Float8HandlerCallback from .garbage_collector import GarbageCollectorCallback from .gpu_memory_monitor import GPUMemoryMonitorCallback @@ -27,6 +31,7 @@ "EvaluatorCallback", "Float8HandlerCallback", "LMEvaluatorCallbackConfig", + "DownstreamEvaluatorCallbackConfig", "MoEHandlerCallback", "GarbageCollectorCallback", "GPUMemoryMonitorCallback", diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py index 6f486f87..a46d3047 100644 --- a/src/olmo_core/train/callbacks/evaluator_callback.py +++ b/src/olmo_core/train/callbacks/evaluator_callback.py @@ -1,18 +1,29 @@ import logging from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional -from olmo_core.data import NumpyDatasetConfig, NumpyPaddedFSLDataset -from olmo_core.distributed.utils import get_world_size +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler + +from olmo_core.data import NumpyDatasetConfig, NumpyPaddedFSLDataset, TokenizerConfig +from olmo_core.distributed.utils import get_rank, get_world_size, is_distributed from olmo_core.eval import Evaluator from olmo_core.eval.lm_evaluator import LMEvaluator from olmo_core.exceptions import OLMoConfigurationError -from olmo_core.utils import format_float, move_to_device +from olmo_core.utils import ( + cuda_sync_debug_mode, + format_float, + get_default_device, + move_to_device, +) from ..common import Duration from .callback import Callback, CallbackConfig if TYPE_CHECKING: + from olmo_eval import HFTokenizer + from ..trainer import Trainer log = logging.getLogger(__name__) @@ -65,7 +76,10 @@ def post_step(self): logits, ce_loss, _ = self.trainer.eval_batch( batch, loss_reduction="none", compute_z_loss=False ) - evaluator.update_metrics(batch, ce_loss, logits) + + # NOTE: might have host-device syncs here but that's okay. + with cuda_sync_debug_mode(0): + evaluator.update_metrics(batch, ce_loss, logits) if eval_step % self.trainer.cancel_check_interval == 0: self.trainer.check_if_canceled() @@ -83,10 +97,11 @@ def post_step(self): # NOTE: going to have a host-device sync here but that's okay. It's only once # per evaluator. metrics = [] - for name, value in evaluator.compute_metrics().items(): - value = value.item() - metrics.append(f" {name}={format_float(value)}") - self.trainer.record_metric(f"eval/{evaluator.name}/{name}", value) + with cuda_sync_debug_mode(0): + for name, value in evaluator.compute_metrics().items(): + value = value.item() + metrics.append(f" {name}={format_float(value)}") + self.trainer.record_metric(f"eval/{evaluator.name}/{name}", value) log.info("Eval metrics:\n" + "\n".join(metrics)) # Restore model to train mode. @@ -124,6 +139,7 @@ def build(self, trainer: "Trainer") -> Callback: global_batch_size=eval_batch_size, collator=trainer.data_loader.collator, device=trainer.device, + dp_process_group=trainer.dp_process_group, ) return EvaluatorCallback( evaluators=[evaluator], @@ -131,3 +147,129 @@ def build(self, trainer: "Trainer") -> Callback: log_interval=self.log_interval, eval_duration=self.eval_duration, ) + + +class DownstreamEvaluator(Evaluator): + metric_type_to_label = { + "f1": "F1 score", + "acc": "accuracy", + "len_norm": "length-normalized accuracy", + "pmi_dc": "PMI-DC accuracy", + "ce_loss": "CE loss", + "bpb": "BPB", + } + + def __init__( + self, + *, + name: str, + task: str, + rank_batch_size: int, + tokenizer: "HFTokenizer", + device: Optional[torch.device] = None, + dp_process_group: Optional[dist.ProcessGroup] = None, + ): + from olmo_eval import ICLMetric, build_task + + self.label = task + self.task = build_task(task, tokenizer) + self.metric = ICLMetric(metric_type=self.task.metric_type).to( + device or get_default_device() + ) + sampler: Optional[DistributedSampler] = None + if is_distributed(): + sampler = DistributedSampler( + self.task, # type: ignore + drop_last=True, + shuffle=False, + num_replicas=get_world_size(dp_process_group), + rank=get_rank(dp_process_group), + ) + + rank_batch_size_instances = max(0, rank_batch_size // self.task.max_sequence_length) + log.info( + f"Using per-rank batch size of {rank_batch_size_instances} instances " + f"for downstream eval task '{task}' with max sequence length {self.task.max_sequence_length:,d} tokens" + ) + + data_loader = DataLoader( + self.task, # type: ignore + batch_size=rank_batch_size_instances, + collate_fn=self.task.collate_fn, + num_workers=0, + sampler=sampler, + ) + + super().__init__( + name=name, batches=data_loader, device=device, dp_process_group=dp_process_group + ) + + def update_metrics( + self, batch: Dict[str, Any], ce_loss: torch.Tensor, logits: torch.Tensor + ) -> None: + del ce_loss + self.metric.update(batch, logits) + + def compute_metrics(self) -> Dict[str, torch.Tensor]: + value = self.metric.compute() + label = f"{self.label} ({self.metric_type_to_label[self.task.metric_type]})" + return {label: value} + + def reset_metrics(self) -> None: + self.metric.reset() + + +@dataclass +class DownstreamEvaluatorCallbackConfig(CallbackConfig): + tasks: List[str] + tokenizer: TokenizerConfig + eval_batch_size: Optional[int] = None + eval_interval: int = 1000 + eval_duration: Duration = field(default_factory=lambda: Duration.epochs(1)) + log_interval: int = 5 + + def build(self, trainer: "Trainer") -> Callback: + from olmo_eval import HFTokenizer + + global_eval_batch_size = ( + self.eval_batch_size + if self.eval_batch_size is not None + else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) + ) + rank_eval_batch_size = global_eval_batch_size // get_world_size(trainer.dp_process_group) + if rank_eval_batch_size == 0: + raise OLMoConfigurationError( + f"'eval_batch_size' of {global_eval_batch_size:,d} tokens is too small for the given world size" + ) + + if self.tokenizer.identifier is None: + raise OLMoConfigurationError( + "Tokenizer 'identifier' required to build a concrete tokenizer" + ) + + tokenizer = HFTokenizer( + self.tokenizer.identifier, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + bos_token_id=self.tokenizer.bos_token_id, + ) + + evaluators: List[Evaluator] = [] + for task in self.tasks: + evaluators.append( + DownstreamEvaluator( + name="downstream", + task=task, + rank_batch_size=rank_eval_batch_size, + tokenizer=tokenizer, + device=trainer.device, + dp_process_group=trainer.dp_process_group, + ) + ) + + return EvaluatorCallback( + evaluators=evaluators, + eval_interval=self.eval_interval, + log_interval=self.log_interval, + eval_duration=self.eval_duration, + ) diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 7b960c5c..a40e3fb8 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -308,7 +308,11 @@ def __post_init__(self): if self._bookkeeping_pg is None and is_distributed(): if backend_supports_cpu(): log.info("Creating new process group for bookkeeping") - self._bookkeeping_pg = dist.new_group() + self._bookkeeping_pg = dist.new_group( + ranks=None + if self.dp_process_group is None + else dist.get_process_group_ranks(self.dp_process_group) + ) else: log.warning( "No CPU backend configured, bookkeeping collectives will occur on the default " @@ -460,7 +464,8 @@ def bookkeeping_device(self) -> torch.device: @property def bookkeeping_pg(self) -> Optional[dist.ProcessGroup]: """ - The process group used for bookkeeping collectives. + The process group used for bookkeeping collectives. This should include the same ranks + as the :data:`dp_process_group`. Since bookkeeping collectives might be done in a separate thread, we need a separate process group to avoid potential race conditions. diff --git a/src/olmo_core/train/utils.py b/src/olmo_core/train/utils.py index d6f0a53f..d6402e41 100644 --- a/src/olmo_core/train/utils.py +++ b/src/olmo_core/train/utils.py @@ -20,6 +20,7 @@ get_world_size, is_distributed, ) +from ..utils import cuda_sync_debug_mode from .common import ReduceType log = logging.getLogger(__name__) @@ -131,7 +132,11 @@ def move_metrics( ] metrics_to_move: Optional[torch.Tensor] = None if metrics_to_move_list: - metrics_to_move = torch.stack(metrics_to_move_list).to(device, non_blocking=non_blocking) + # NOTE: this is a known host-device sync so we don't need the warning + with cuda_sync_debug_mode(0): + metrics_to_move = torch.stack(metrics_to_move_list).to( + device, non_blocking=non_blocking + ) # Collect output with moved tensors. target: Dict[int, Dict[str, torch.Tensor]] = OrderedDict() diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py index cc049de9..a4389629 100644 --- a/src/olmo_core/utils.py +++ b/src/olmo_core/utils.py @@ -7,6 +7,7 @@ import time import uuid import warnings +from contextlib import contextmanager from datetime import datetime from itertools import cycle, islice from queue import Queue @@ -608,3 +609,20 @@ def add_sub_dict(prefix: str, sub_dict: Dict[str, Any]): out[k] = v return out + + +@contextmanager +def cuda_sync_debug_mode(debug_mode: Union[int, str]): + """ + A context manager for temporarily setting the CUDA sync debug mode. + """ + current_mode: Optional[int] = None + + try: + if torch.cuda.is_available(): + current_mode = torch.cuda.get_sync_debug_mode() + torch.cuda.set_sync_debug_mode(debug_mode) + yield + finally: + if current_mode is not None: + torch.cuda.set_sync_debug_mode(debug_mode) diff --git a/src/test/nn/attention_test.py b/src/test/nn/attention_test.py index 398e74aa..3d18d266 100644 --- a/src/test/nn/attention_test.py +++ b/src/test/nn/attention_test.py @@ -20,7 +20,7 @@ ) @pytest.mark.parametrize( "n_kv_heads", - [pytest.param(None, id="MHA"), pytest.param(1, id="MQA"), pytest.param(4, id="GQA")], + [pytest.param(None, id="MHA"), pytest.param(1, id="MQA"), pytest.param(2, id="GQA")], ) @pytest.mark.parametrize( "use_flash", @@ -55,7 +55,7 @@ def test_attention( attention = Attention( d_model=d_model, - n_heads=8, + n_heads=4, n_kv_heads=n_kv_heads, use_flash=use_flash, init_device=device.type,