From e27ba74550f12a685b95aded46c88f845a7e2d6d Mon Sep 17 00:00:00 2001 From: Pete Walsh Date: Tue, 5 Nov 2024 12:25:19 -0800 Subject: [PATCH] Update throughput numbers, add `logging_configured()` util function (#81) --- CHANGELOG.md | 1 + README.md | 4 ++-- src/olmo_core/distributed/utils.py | 16 +++++++++++++++- src/olmo_core/utils.py | 17 ++++++++++++++--- src/scripts/train/OLMo-1B.py | 2 +- 5 files changed, 33 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d7d9b8c..58888aae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `retries` field to `BeakerLaunchConfig`. - Allow running on Augusta cluster with existing train scripts. +- Added `olmo_core.utils.logging_configured()` function to check if logging has been configured. ## [v1.6.0](https://github.com/allenai/OLMo-core/releases/tag/v1.6.0) - 2024-11-01 diff --git a/README.md b/README.md index 01a3b28f..3effcfc6 100644 --- a/README.md +++ b/README.md @@ -39,8 +39,8 @@ Throughput numbers from these scripts with various different configuration setti | Model size | Model arch | Context length | Precision | Throughput[^1] | Training script | Commandline overrides                                    | | :--------: | :--------: | :------------: | :-------: | -----------: | :----------- | :-------- | -| **1B** | OLMo-1124 | 4096 | BF16 | 44,000 TPS | `OLMo-1B.py` | | -| | | 4096 | BF16/FP8[^2] | 51,000 TPS | `OLMo-1B.py` | `--model.float8_config.enabled=true` | +| **1B** | OLMo-1124 | 4096 | BF16 | 55,000 TPS | `OLMo-1B.py` | | +| | | 4096 | BF16/FP8[^2] | 65,000 TPS | `OLMo-1B.py` | `--model.float8_config.enabled=true` | | **7B** | OLMo-1124 | 4096 | BF16 | 10,000 TPS | `OLMo-7B.py` | | | | | 4096 | BF16/FP8 | 13,000 TPS | `OLMo-7B.py` | `--model.float8_config.enabled=true` | | **8B** | Llama | 4096 | BF16 | 9,500 TPS | `Llama-8B.py` | | diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index 92ff8f55..2bd47e5b 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -2,6 +2,7 @@ Distributed helpers, most of which work in a non-distributed context as well for API unity. """ +import logging import os from datetime import timedelta from typing import List, Optional, TypeVar @@ -13,7 +14,7 @@ from ..config import StrEnum from ..exceptions import OLMoConfigurationError, OLMoEnvironmentError -from ..utils import get_default_device, move_to_device, set_env_var +from ..utils import get_default_device, logging_configured, move_to_device, set_env_var OLMO_SHARED_FS_ENV_VAR = "OLMO_SHARED_FS" OLMO_FS_LOCAL_RANK_ENV_VAR = "FS_LOCAL_RANK" @@ -23,6 +24,9 @@ BEAKER_HOSTNAME_ENV_VAR = "BEAKER_NODE_HOSTNAME" +log = logging.getLogger(__name__) + + def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minutes=30)): """ Initialize the distributed process group with the given backend(s) and check/set the @@ -100,6 +104,16 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut validate_env_vars() + msg = ( + f"Global rank {get_rank()} " + f"= local rank {get_local_rank()} " + f"= file system local rank {get_fs_local_rank()}" + ) + if logging_configured(): + log.warning(msg) + else: + print(msg) + def validate_env_vars(): """ diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py index a4389629..90cdfd64 100644 --- a/src/olmo_core/utils.py +++ b/src/olmo_core/utils.py @@ -171,19 +171,18 @@ def has_flash_attn() -> bool: def set_env_var(name: str, value: str, override: bool = False, secret: bool = False): - global _LOGGING_CONFIGURED value_str = "****" if secret else value if name in os.environ: if override and os.environ[name] != value: msg = f"Overriding env var '{name}' to '{value_str}'" - if _LOGGING_CONFIGURED: + if logging_configured(): log.warning(msg) else: print(msg) os.environ[name] = value else: msg = f"Setting env var '{name}' to '{value_str}'" - if _LOGGING_CONFIGURED: + if logging_configured(): log.info(msg) else: print(msg) @@ -314,6 +313,18 @@ def local_rank0_filter(record: logging.LogRecord) -> int: _LOGGING_CONFIGURED = True +def logging_configured() -> bool: + """ + Returns ``True`` if logging has been configured (like with :func:`setup_logging()`), + otherwise returns ``False``. + """ + if _LOGGING_CONFIGURED: + return True + else: + # Otherwise check if the root logger has any handlers. + return len(logging.getLogger().handlers) > 0 + + def excepthook(exctype, value, traceback): """ Used to patch ``sys.excepthook`` in order to log exceptions. Use :func:`install_excepthook()` diff --git a/src/scripts/train/OLMo-1B.py b/src/scripts/train/OLMo-1B.py index 7f222cee..b4bc97c5 100644 --- a/src/scripts/train/OLMo-1B.py +++ b/src/scripts/train/OLMo-1B.py @@ -38,7 +38,7 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig: return ( TrainerConfig( save_folder=common.save_folder, - rank_microbatch_size=4 * 4096, + rank_microbatch_size=8 * 4096, save_overwrite=True, metrics_collect_interval=10, cancel_check_interval=1,