diff --git a/README.md b/README.md
index afcbd51..642fbf2 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
 ---
 
 ```bash
-/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/unconst/boltzmann/master/run.sh)"
+/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/unconst/boltzmann/master/scripts/run.sh)"
 ```
 
 ---
diff --git a/boltz/common.py b/boltz/common.py
new file mode 100644
index 0000000..11d7c74
--- /dev/null
+++ b/boltz/common.py
@@ -0,0 +1,698 @@
+# The MIT License (MIT)
+# © 2024 Chakana.tech

+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+# the Software.

+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.

+import os
+import io
+import sys
+import uuid
+import time
+import fcntl
+import torch
+import uvloop
+import hashlib
+import asyncio
+import logging
+import tempfile
+import aiofiles
+import traceback
+import numpy as np
+import aiobotocore
+import bittensor as bt
+import botocore.config
+import torch.distributed as dist
+from torch import nn
+from typing import List, Dict
+from dotenv import dotenv_values
+from types import SimpleNamespace
+from rich.logging import RichHandler
+from filelock import FileLock, Timeout
+from aiobotocore.session import get_session
+from rich.highlighter import NullHighlighter
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import FullStateDictConfig, StateDictType

+# Configure the standard logging module with a Rich console handler.
+FORMAT = "%(message)s"
+logging.basicConfig(
+    level=logging.INFO,
+    format=FORMAT,
+    datefmt="[%X]",
+    handlers=[
+        RichHandler(
+            markup=True,
+            rich_tracebacks=True,
+            highlighter=NullHighlighter(),
+            show_level=False,
+            show_time=False,
+            show_path=False,
+        )
+    ],
+)
+logger = logging.getLogger("rich")
+logger.setLevel(logging.INFO)

+# The stdlib logging module defines no TRACE level; register one below DEBUG
+# so that trace() and the logger.trace(...) calls in this file resolve.
+logging.TRACE = 5
+logging.addLevelName(logging.TRACE, "TRACE")


+def _trace(self, message, *args, **kwargs):
+    if self.isEnabledFor(logging.TRACE):
+        self._log(logging.TRACE, message, args, **kwargs)


+logging.Logger.trace = _trace


+def debug():
+    logger.setLevel(logging.DEBUG)


+def trace():
+    logger.setLevel(logging.TRACE)


+# Log helper.
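+# T() returns the current wall-clock time; P(window, duration) renders the
+# rich-markup "window (seconds)" tag used throughout the run logs.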
+def T(): + return time.time() + + +def P(w, d): + return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])" + + +# Load environment variables +env_config = {**dotenv_values(".env"), **os.environ} +AWS_ACCESS_KEY_ID = env_config.get("AWS_ACCESS_KEY_ID") +AWS_SECRET_ACCESS_KEY = env_config.get("AWS_SECRET_ACCESS_KEY") + +# Configure the S3 client +client_config = botocore.config.Config( + max_pool_connections=256, +) + +# Set uvloop as the event loop policy +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +# Define a semaphore to limit concurrent downloads (adjust as needed) +semaphore = asyncio.Semaphore(1000) + + +async def get_slices(filename: str, device: str) -> Dict[str, torch.Tensor]: + # Attempt to acquire the lock with a timeout of 1 second. + lock: FileLock = FileLock(f"{filename}.lock") + with lock.acquire(timeout=5): + pass + return torch.load( + filename, + map_location=torch.device(device), + weights_only=True, + ) + +async def apply_slices_to_model( + model: nn.Module, window: int, seed: str, compression: int, key: str = "slice" +) -> List[str]: + rank = dist.get_rank() + world_size = dist.get_world_size() + logger.debug(f"Rank {rank}: Starting apply_slices_to_model for window {window}") + + # Get indices associated with the window (all ranks must participate) + try: + indices_dict: Dict[str, torch.LongTensor] = await get_indices_for_window( + model, seed, compression + ) + except Exception as e: + logger.exception(f"Rank {rank}: Failed to get indices: {e}") + sys.exit(1) # Ensure all ranks exit to prevent hangs + + logger.debug(f"Rank {rank}: Retrieved indices for parameters") + + # Load slice files on rank 0 and broadcast the list to all ranks + if rank == 0: + try: + slice_files: List[str] = await load_files_for_window(window=window, key=key) + logger.debug(f"Rank {rank}: Loaded {len(slice_files)} slice files for window {window}") + except Exception as e: + logger.exception(f"Rank {rank}: Failed to load slice files: {e}") + slice_files = [] + else: + slice_files = None # Placeholder for other ranks + + # Broadcast slice_files from rank 0 to all ranks + try: + slice_files_list = [slice_files] if rank == 0 else [None] + dist.broadcast_object_list(slice_files_list, src=0) + slice_files = slice_files_list[0] + logger.debug(f"Rank {rank}: Received slice_files: {slice_files}") + except Exception as e: + logger.exception(f"Rank {rank}: Failed to broadcast slice files: {e}") + sys.exit(1) # Ensure all ranks exit + + if not slice_files: + logger.warning(f"Rank {rank}: No slice files to process for window {window}") + return slice_files # All ranks return here synchronously + + # Initialize dictionaries to keep track of sums and counts + param_sums: Dict[str, torch.Tensor] = {} + slices_per_param: Dict[str, int] = {} + + # Rank 0 processes the slice files and reconstructs the parameters + if rank == 0: + for name in indices_dict.keys(): + param_sums[name] = torch.zeros( + indices_dict[name].numel(), dtype=torch.float32 + ) + slices_per_param[name] = 0 + + # Process each slice file + for file_i in slice_files: + logger.debug(f"Rank {rank}: Processing slice file {file_i}") + try: + slice_i = await get_slices(file_i, "cpu") # Load slices to CPU + for name in slice_i.keys(): + if name in param_sums: + param_sums[name] += slice_i[name].cpu() + slices_per_param[name] += 1 + except Exception as e: + logger.exception( + f"Rank {rank}: Error processing slice file {file_i}: {e}" + ) + continue + + # Average the sums to get the updated parameters + for name in 
param_sums.keys(): + if slices_per_param[name] > 0: + param_sums[name] /= slices_per_param[name] + else: + logger.warning(f"Rank {rank}: No slices applied for parameter {name}") + + # Broadcast param_sums and slices_per_param from rank 0 to all ranks + try: + if rank == 0: + data_to_broadcast = (param_sums, slices_per_param) + else: + data_to_broadcast = None + + data_list = [data_to_broadcast] + dist.broadcast_object_list(data_list, src=0) + if rank != 0: + param_sums, slices_per_param = data_list[0] + + logger.debug(f"Rank {rank}: Received param_sums and slices_per_param") + except Exception as e: + logger.exception(f"Rank {rank}: Failed to broadcast parameter sums: {e}") + sys.exit(1) # Ensure all ranks exit + + # All ranks proceed to update the model parameters + try: + # Retrieve the full state_dict (all ranks must participate) + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + state_dict: Dict[str, torch.Tensor] = model.state_dict() + + for name, param in state_dict.items(): + if ( + name in indices_dict + and name in param_sums + and slices_per_param[name] > 0 + ): + indices = indices_dict[name].to("cpu") + updated_values = param_sums[name].to(param.dtype) + # Update the parameter values at the specified indices + param.view(-1)[indices] = updated_values + logger.trace(f"Rank {rank}: Updated parameter {name}") + else: + logger.trace(f"Rank {rank}: No updates applied to parameter {name}") + + # Broadcast the updated state_dict from rank 0 to all other ranks + state_dict_list = [state_dict] if rank == 0 else [None] + dist.broadcast_object_list(state_dict_list, src=0) + state_dict = state_dict_list[0] + logger.debug(f"Rank {rank}: Received updated state_dict from broadcast") + + # Load the updated state_dict back into the model (all ranks must participate) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + model.load_state_dict(state_dict, strict=False) + logger.debug(f"Rank {rank}: Model parameters updated") + + except Exception as e: + logger.exception(f"Rank {rank}: Failed to update model parameters: {e}") + sys.exit(1) # Ensure all ranks exit + + return slice_files + + +async def upload_slice_for_window( + bucket: str, + model: torch.nn.Module, + window: int, + seed: str, + wallet, + compression: int, + key: str = "slice", +): + """ + Uploads a compressed slice of an FSDP model to a storage bucket. + + Args: + bucket (str): Name of the storage bucket. + model (torch.nn.Module): The FSDP-wrapped PyTorch model to be sliced and uploaded. + window (int): The window identifier. + seed (str): The seed used for generating indices. + wallet (bt.wallet): The wallet object containing the hotkey. + compression (int): The compression factor. + key (str): The key used to identify the slices. + + Returns: + None + + Example Usage: + await upload_slice_for_window( + bucket='my-bucket', + model=my_fsdp_model, + window=42, + seed='1234', + wallet=my_wallet, + compression=10, + key='slice' + ) + + Notes: + - This function ensures that all ranks participate in necessary collective operations with FSDP. + - Only Rank 0 performs the actual upload to the storage bucket. + - All ranks must participate in collective operations to prevent hangs. 
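+        - A final dist.barrier() keeps non-zero ranks from racing ahead of the upload.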
+ """ + rank = dist.get_rank() + logger.debug(f"Rank {rank}: Starting upload_slice_for_window") + + # Generate the filename based on the window and wallet hotkey + filename = f"{key}-{window}-{wallet.hotkey.ss58_address}.pt" + logger.debug(f"Rank {rank}: Filename for slice: {filename}") + + try: + # Get indices for slicing (all ranks must participate) + indices = await get_indices_for_window(model, seed, compression) + logger.debug(f"Rank {rank}: Retrieved indices for slicing") + + # Retrieve the full state_dict (all ranks must participate) + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + state_dict = model.state_dict() + logger.debug(f"Rank {rank}: Retrieved state_dict for slicing") + + # Prepare the sliced state_dict + slice_dict = {} + for name, param in state_dict.items(): + if name in indices: + param_indices = indices[name].to("cpu") + sliced_param = param.view(-1)[param_indices].cpu() + slice_dict[name] = sliced_param + logger.trace(f"Rank {rank}: Sliced parameter {name}") + else: + logger.trace(f"Rank {rank}: No indices for parameter {name}; skipping") + except Exception as e: + logger.exception(f"Rank {rank}: Error during slicing: {e}") + sys.exit(1) # Ensure all ranks exit to prevent hangs + + # Only Rank 0 saves and uploads the slice + if rank == 0: + try: + # Save the slice_dict to a temporary file + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + torch.save(slice_dict, temp_file) + temp_file_name = temp_file.name + logger.debug(f"Rank {rank}: Saved slice to temporary file {temp_file_name}") + + # Initialize S3 client + session = get_session() + async with session.create_client( + "s3", + region_name="us-east-1", # Replace with your region + config=client_config, + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, + ) as s3_client: + try: + # Upload the file to S3 + with open(temp_file_name, "rb") as f: + await s3_client.put_object(Bucket=bucket, Key=filename, Body=f) + logger.debug( + f"Rank {rank}: Uploaded slice to bucket {bucket} with key {filename}" + ) + + # Optionally set object ACL to public-read + await s3_client.put_object_acl( + Bucket=bucket, Key=filename, ACL="public-read" + ) + logger.debug( + f"Rank {rank}: Set object ACL to public-read for {filename}" + ) + except Exception as e: + logger.exception( + f"Rank {rank}: Failed to upload slice to storage: {e}" + ) + finally: + # Clean up the temporary file + os.remove(temp_file_name) + logger.debug( + f"Rank {rank}: Removed temporary file {temp_file_name}" + ) + except Exception as e: + logger.exception( + f"Rank {rank}: Error during saving or uploading slice: {e}" + ) + sys.exit(1) # Ensure all ranks exit to prevent hangs + else: + logger.debug( + f"Rank {rank}: Slice preparation complete. Waiting for Rank 0 to upload." + ) + + # Synchronize all ranks to ensure upload is completed before proceeding + dist.barrier() + logger.debug(f"Rank {rank}: Completed upload_slice_for_window") + + +async def get_indices_for_window( + model: torch.nn.Module, seed: str, compression: int +) -> Dict[str, torch.LongTensor]: + """ + Computes the indices for the given window and compression factor. + + Args: + model (torch.nn.Module): The PyTorch model. + seed (str): The window seed identifier. + compression (int): The compression factor. + + Returns: + Dict[str, torch.LongTensor]: A dictionary mapping parameter names to index tensors. 
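+
+    Example (illustrative sketch; key names depend on the model):
+        >>> # Miners and validators call this with the same (seed, compression)
+        >>> # pair, so they derive identical index sets with no communication.
+        >>> indices = await get_indices_for_window(model, seed="window-42", compression=100)
+        >>> indices["model.embed_tokens.weight"].numel()  # ~ numel // compression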
+    """
+    logger.debug(
+        f"Starting get_indices_for_window with seed={seed}, compression={compression}"
+    )
+    result = {}
+
+    # Seed the random number generator with the seed
+    seed_int = int(hashlib.md5(str(seed).encode("utf-8")).hexdigest(), 16) % (2**32)
+    logger.trace(f"Converted seed '{seed}' to integer: {seed_int}")
+    rng = np.random.default_rng(seed_int)
+    logger.trace(f"Initialized random number generator with seed: {seed_int}")
+
+    # Retrieve the full state dict to get parameter shapes
+    logger.trace("Retrieving full state dict")
+    dist.barrier()
+    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
+        state_dict = model.state_dict()
+    logger.trace(f"Retrieved state dict with {len(state_dict)} parameters")
+
+    # For each parameter, compute the indices
+    for name, param in state_dict.items():
+        logger.trace(f"Processing parameter: {name}")
+        numel = param.numel()
+        logger.trace(f"Parameter {name} has {numel} elements")
+        num_indices = max(1, int(numel // compression))
+        logger.trace(f"Selecting {num_indices} indices for parameter {name}")
+        indices = rng.choice(numel, size=num_indices, replace=False)
+        logger.trace(
+            f"Generated indices for {name}: min={indices.min()}, max={indices.max()}, shape={indices.shape}"
+        )
+        result[name] = torch.from_numpy(indices).long().cpu()
+        logger.trace(f"Converted indices for {name} to torch.LongTensor on CPU")
+
+    logger.trace(
+        f"Finished get_indices_for_window, returning dict with {len(result)} entries"
+    )
+    return result


+async def download_file(s3_client, bucket: str, filename: str) -> str:
+    """
+    Downloads a file from S3, streaming it to disk in 1 MB chunks and using a
+    file lock to deduplicate concurrent downloads of the same object.
+
+    Args:
+        s3_client: The S3 client.
+        bucket (str): Name of the S3 bucket.
+        filename (str): The S3 object key (filename).
+
+    Returns:
+        str: The path to the downloaded file in the temporary directory.
+    """
+    async with semaphore:
+        temp_file = os.path.join(tempfile.gettempdir(), filename)
+        # Check if the file exists.
+        if os.path.exists(temp_file):
+            logger.debug(f"File {temp_file} already exists, skipping download.")
+            return temp_file
+        lock_file = f"{temp_file}.lock"
+        lock = FileLock(lock_file)
+        try:
+            # Try to acquire the lock with a timeout
+            with lock.acquire(timeout=1):
+                # Proceed to download the file
+                logger.debug(f"Downloading file {filename} to {temp_file}")
+                # Validate the object exists before streaming (raises if missing).
+                await s3_client.head_object(Bucket=bucket, Key=filename)
+                CHUNK_SIZE = 1 * 1024 * 1024  # 1 MB
+
+                response = await s3_client.get_object(Bucket=bucket, Key=filename)
+                async with aiofiles.open(temp_file, "wb") as outfile:
+                    while True:
+                        chunk = await response["Body"].read(CHUNK_SIZE)
+                        if not chunk:
+                            break
+                        await outfile.write(chunk)
+
+                logger.debug(f"Successfully downloaded file {filename} to {temp_file}")
+                return temp_file
+
+        except Timeout:
+            logger.error(
+                f"Timeout occurred while trying to acquire lock on {lock_file}"
+            )
+            return None
+        except Exception as e:
+            logger.exception(
+                f"Failed to download file {filename} from bucket {bucket}: {e}"
+            )
+            return None
+        finally:
+            # The lock is automatically released when exiting the 'with' block
+            pass


+async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window: int):
+    """
+    Handles downloading a single file from S3.
+
+    Args:
+        s3_client: The S3 client.
+        bucket (str): Name of the S3 bucket.
+        filename (str): The S3 object key (filename).
+ hotkey (str): The hotkey identifier. + window (int): The window identifier. + + Returns: + SimpleNamespace: An object containing file metadata and the path to the downloaded file. + """ + logger.debug(f"Handling file {filename} for window {window} and hotkey {hotkey}") + temp_file = await download_file(s3_client, bucket, filename) + if temp_file: + return SimpleNamespace( + bucket=bucket, + hotkey=hotkey, + filename=filename, + window=window, + temp_file=temp_file, + ) + return None + + +async def process_bucket( + s3_client, bucket: str, windows: List[int], key: str = "slice" +): + """ + Processes an S3 bucket to download files matching the given windows. + + Args: + s3_client: The S3 client. + bucket (str): Name of the S3 bucket. + windows (List[int]): A list of window identifiers. + + Returns: + List[SimpleNamespace]: A list of file metadata and paths for downloaded files. + """ + logger.debug(f"Processing bucket {bucket} for window {windows}") + files = [] + paginator = s3_client.get_paginator("list_objects_v2") + + for window in windows: + prefix = f"{key}-{window}" + logger.debug(f"Listing objects with prefix {prefix}") + async for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + logger.trace(f"Processing page for prefix {prefix}") + if "Contents" not in page: + logger.trace(f"No contents found for prefix {prefix}") + continue + download_tasks = [] + for obj in page.get("Contents", []): + filename = obj["Key"] + logger.trace(f"Processing object with key {filename}") + try: + parts = filename.split("-") + slice_window = int(parts[1]) + slice_hotkey = parts[2].split(".")[0] + logger.trace( + f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}" + ) + if slice_window == window: + download_tasks.append( + handle_file( + s3_client, bucket, filename, slice_hotkey, slice_window + ) + ) + except Exception: + logger.exception(f"Error processing filename {filename}") + continue + # Download the files concurrently + results = await asyncio.gather(*download_tasks) + files.extend([res for res in results if res]) + logger.trace(f"Completed processing page for prefix {prefix}") + logger.trace(f"Completed processing bucket {bucket} for windows {windows}") + return files + + +async def download_slices_for_buckets_and_windows( + buckets: List[str], windows: List[int], key: str = "slice" +) -> Dict[int, List[SimpleNamespace]]: + """ + Downloads files from multiple S3 buckets for the given windows. + + Args: + buckets (List[str]): A list of S3 bucket names. + windows (List[int]): A list of window identifiers. + + Returns: + Dict[int, List[SimpleNamespace]]: A dictionary mapping windows to lists of file metadata and paths. 
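+
+    Example (illustrative bucket names and windows):
+        >>> windows_dict = await download_slices_for_buckets_and_windows(
+        ...     buckets=["bucket-a", "bucket-b"], windows=[10, 11]
+        ... )
+        >>> files_for_10 = windows_dict.get(10, [])  # SimpleNamespace entries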
+    """
+    logger.debug(f"Downloading files for buckets {set(buckets)} and windows {windows}")
+    session = get_session()
+    async with session.create_client(
+        "s3",
+        region_name="us-east-1",
+        config=client_config,
+        aws_access_key_id=AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+    ) as s3_client:
+        tasks = []
+        for bucket in set(buckets):
+            if not bucket:
+                continue
+            tasks.append(process_bucket(s3_client, bucket, windows, key))
+        results = await asyncio.gather(*tasks)
+        # Flatten the list of lists
+        files = [item for sublist in results for item in sublist]
+
+        # Create a dictionary with windows as keys and list of files as values
+        windows_dict = {}
+        for file in files:
+            window = file.window
+            if window not in windows_dict:
+                windows_dict[window] = []
+            windows_dict[window].append(file)
+
+        logger.debug(f"Downloaded all files grouped by windows: {windows}")
+        return windows_dict


+async def load_files_for_window(window: int, key: str = "slice") -> List[str]:
+    """
+    Retrieves the paths to downloaded window files from the temporary directory.
+
+    Args:
+        window (int): The window identifier.
+        key (str): The filename key prefix (defaults to "slice").
+
+    Returns:
+        List[str]: A list of file paths corresponding to the window.
+    """
+    logger.debug(f"Retrieving files for window {window} from temporary directory")
+    temp_dir = tempfile.gettempdir()
+    window_files = []
+    for filename in os.listdir(temp_dir):
+        if filename.startswith(f"{key}-{window}-") and filename.endswith(".pt"):
+            window_files.append(os.path.join(temp_dir, filename))
+            logger.debug(f"Found file {filename} for window {window}")
+    return window_files


+async def delete_files_before_window(window_max: int, key: str = "slice"):
+    """
+    Deletes all local files whose window id is before the given window_max.
+
+    Args:
+        window_max (int): The maximum window id. Files with window ids less than this value will be deleted.
+        key (str): The filename key prefix (defaults to "slice").
+    """
+    logger.debug(f"Deleting files with window id before {window_max}")
+    temp_dir = tempfile.gettempdir()
+    for filename in os.listdir(temp_dir):
+        if filename.startswith(f"{key}-") and (
+            filename.endswith(".pt") or filename.endswith(".lock")
+        ):
+            try:
+                parts = filename.split("-")
+                window_id = int(parts[1])
+                if window_id < window_max:
+                    # os.listdir returns bare names; join with temp_dir so the
+                    # delete targets the file's actual path.
+                    file_path = os.path.join(temp_dir, filename)
+                    if os.path.exists(file_path):
+                        os.remove(file_path)
+                        logger.debug(f"Deleted file {file_path}")
+            except Exception as e:
+                logger.error(f"Error deleting file {filename}: {e}")


+async def delete_files_from_bucket_before_window(
+    bucket: str, window_max: int, key: str = "slice"
+):
+    """
+    Deletes all files in the specified S3 bucket whose window id is before the given window_max.
+
+    Args:
+        bucket (str): The name of the S3 bucket.
+        window_max (int): The maximum window id. Files with window ids less than this value will be deleted.
+        key (str): The filename key prefix (defaults to "slice").
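+
+    Example (illustrative):
+        >>> # Remove every slice object with a window id below 100.
+        >>> await delete_files_from_bucket_before_window("my-bucket", 100)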
+    """
+    logger.debug(
+        f"Deleting files in bucket {bucket} with window id before {window_max}"
+    )
+    session = get_session()
+    async with session.create_client(
+        "s3",
+        region_name="us-east-1",
+        config=client_config,
+        aws_access_key_id=AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
+    ) as s3_client:
+        try:
+            response = await s3_client.list_objects_v2(Bucket=bucket)
+            if "Contents" in response:
+                for obj in response["Contents"]:
+                    filename = obj["Key"]
+                    if filename.startswith(f"{key}-") and filename.endswith(".pt"):
+                        try:
+                            parts = filename.split("-")
+                            window_id = int(parts[1])
+                            if window_id < window_max:
+                                await s3_client.delete_object(
+                                    Bucket=bucket, Key=filename
+                                )
+                                logger.debug(
+                                    f"Deleted file {filename} from bucket {bucket}"
+                                )
+                        except Exception as e:
+                            logger.error(
+                                f"Error deleting file {filename} from bucket {bucket}: {e}"
+                            )
+        except Exception as e:
+            logger.error(f"Error listing objects in bucket {bucket}: {e}")
diff --git a/dataset.py b/boltz/dataset.py
similarity index 100%
rename from dataset.py
rename to boltz/dataset.py
diff --git a/boltz/fsdp.py b/boltz/fsdp.py
new file mode 100644
index 0000000..abe8587
--- /dev/null
+++ b/boltz/fsdp.py
@@ -0,0 +1,144 @@
+from typing import List, Tuple

+import os
+import functools
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import ShardingStrategy
+from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

+from boltz.common import logger


+def fsdp_auto_wrap_policy(model, transformer_layer_names):
+    def lambda_policy_fn(module):
+        # Wrap leaf modules that own trainable weights (e.g. embeddings, lm_head).
+        return (
+            len(list(module.named_children())) == 0
+            and getattr(module, "weight", None) is not None
+            and module.weight.requires_grad
+        )

+    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
+    transformer_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=set(transformer_layer_names)
+    )

+    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
+    return auto_wrap_policy


+def wrap_model_with_fsdp(
+    model: nn.Module,
+    device_id: int,
+    transformer_layer_names: List[nn.Module],
+    is_distributed: bool,
+    device: torch.device
+) -> nn.Module:
+    """
+    Wraps a PyTorch model with Fully Sharded Data Parallel (FSDP) if distributed training is enabled.
+
+    Args:
+        model (nn.Module): The PyTorch model to wrap
+        device_id (int): The local device ID for FSDP
+        transformer_layer_names (List[nn.Module]): List of transformer layer classes to wrap
+        is_distributed (bool): Whether distributed training is enabled
+        device (torch.device): The device to move the model to
+
+    Returns:
+        nn.Module: The wrapped model, either with FSDP or moved to device
+
+    Example:
+        >>> model = LlamaForCausalLM(config)
+        >>> wrapped_model = wrap_model_with_fsdp(
+        ...     model=model,
+        ...     device_id=local_rank,
+        ...     transformer_layer_names=[LlamaDecoderLayer],
+        ...     is_distributed=dist.is_initialized(),
+        ...     device=torch.device("cuda", local_rank)
+        ...     )
+    """
+    if is_distributed:
+        # Create the custom auto wrap policy
+        auto_wrap_policy = fsdp_auto_wrap_policy(
+            model=model,
+            transformer_layer_names=transformer_layer_names
+        )

+        # Wrap with FSDP
+        model = FSDP(
+            model,
+            device_id=device_id,
+            auto_wrap_policy=auto_wrap_policy,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+        )
+        logger.info(f"Model wrapped with FSDP on device {device} using custom auto wrap policy.")
+    else:
+        # Move to device for single-process execution
+        model = model.to(device)
+        logger.info(f"Model moved to device {device}.")

+    return model


+def initialize_distributed_training(
+    device_config: str,
+) -> Tuple[int, int, int, torch.device]:
+    """
+    Initializes distributed training configuration and device settings.
+
+    Args:
+        device_config (str): Device configuration string (e.g. 'cuda:0')
+
+    Returns:
+        Tuple containing:
+            local_rank (int): Local process rank
+            global_rank (int): Global process rank
+            world_size (int): Total number of processes
+            device (torch.device): Torch device object
+
+    Example:
+        >>> local_rank, global_rank, world_size, device = initialize_distributed_training(
+        ...     device_config='cuda:0'
+        ... )
+    """
+    # Initialize distributed training if CUDA available and running in distributed mode
+    if torch.cuda.is_available() and "LOCAL_RANK" in os.environ:
+        # Get distributed training parameters from environment
+        local_rank = int(os.environ["LOCAL_RANK"])
+        global_rank = int(os.environ["RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])

+        # Validate GPU availability
+        num_gpus = torch.cuda.device_count()
+        if local_rank >= num_gpus:
+            raise ValueError(f"Local rank {local_rank} exceeds number of available GPUs {num_gpus}.")

+        # Configure device
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)

+        # Initialize process group
+        dist.init_process_group(backend='nccl')
+        logger.info(f"Distributed training initialized on rank {global_rank} out of {world_size} processes.")

+    else:
+        # Single process execution settings
+        local_rank = 0
+        global_rank = 0
+        world_size = 1
+        device = torch.device(device_config)
+        logger.warning("Distributed training is not initialized. Running on a single process.")

+    return local_rank, global_rank, world_size, device
diff --git a/hparams.py b/boltz/hparams.py
similarity index 99%
rename from hparams.py
rename to boltz/hparams.py
index 2dde865..744d2e1 100644
--- a/hparams.py
+++ b/boltz/hparams.py
@@ -22,7 +22,7 @@
 from types import SimpleNamespace
 from transformers import AutoTokenizer, LlamaConfig
 
-from common import *
+from boltz.common import *
 
 # Cache file path
 HPARAMS_FILE = "hparams.json"
diff --git a/common.py b/common.py
deleted file mode 100644
index 5f12be6..0000000
--- a/common.py
+++ /dev/null
@@ -1,502 +0,0 @@
-# The MIT License (MIT)
-# © 2024 Chakana.tech
-
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
- -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -import os -import io -import sys -import uuid -import time -import fcntl -import torch -import uvloop -import hashlib -import asyncio -import logging -import tempfile -import aiofiles -import numpy as np -import aiobotocore -import bittensor as bt -import botocore.config -from typing import List, Dict -from dotenv import dotenv_values -from types import SimpleNamespace -from rich.logging import RichHandler -from filelock import FileLock, Timeout -from aiobotocore.session import get_session -from rich.highlighter import NullHighlighter - -# Configure loguru logger -FORMAT = "%(message)s" -logging.basicConfig( - level=logging.INFO, - format=FORMAT, - datefmt="[%X]", - handlers=[ - RichHandler( - markup=True, - rich_tracebacks=True, - highlighter=NullHighlighter(), - show_level=False, - show_time=False, - show_path=False - ) - ] -) -logger = logging.getLogger("rich") -logger.setLevel(logging.INFO) -def debug(): - logger.setLevel(logging.DEBUG) -def trace(): - logger.setLevel(logging.TRACE) -# Log helper. -def T(): return time.time() -def P( w, d ): return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])" - -# Load environment variables -env_config = {**dotenv_values(".env"), **os.environ} -AWS_ACCESS_KEY_ID = env_config.get('AWS_ACCESS_KEY_ID') -AWS_SECRET_ACCESS_KEY = env_config.get('AWS_SECRET_ACCESS_KEY') - -# Configure the S3 client -client_config = botocore.config.Config( - max_pool_connections=256, -) - -# Set uvloop as the event loop policy -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -# Define a semaphore to limit concurrent downloads (adjust as needed) -semaphore = asyncio.Semaphore(1000) - -async def get_slices( filename:str, device:str ) -> Dict[str, torch.Tensor]: - # Attempt to acquire the lock with a timeout of 1 second. - lock: FileLock = FileLock(f"{filename}.lock") - with lock.acquire(timeout=5): - pass - return torch.load( - filename, - map_location=torch.device(device), - weights_only = True, - ) - -async def apply_slices_to_model(model: torch.nn.Module, window: int, seed: str, compression: int, key:str = 'slice') -> List[str]: - """ - Applies slices from a specific window to the given model. - - Args: - model (torch.nn.Module): The PyTorch model to which the slices will be applied. - window (int): The window identifier. - seed (str): The seed used for generating indices. - compression (int): The compression factor. - - Returns: - List[str]: A list of all the slice files that were applied. - """ - # First get the indices associated with the window given the model. - indices_dict = await get_indices_for_window(model, seed, compression) - - # Load all the slices associated with this window. - slice_files = await load_files_for_window(window=window, key = key) - - # Dictionary to keep track of the number of slices applied per parameter. - slices_per_param = {name: 0 for name, _ in model.named_parameters()} - - # Dictionary to accumulate the sum of values for each parameter. 
- param_sums = {name: torch.zeros_like(param.data) for name, param in model.named_parameters()} - - # Iterate over each slice file and compute the sum of values. - for file_i in slice_files: - # Create a file lock to ensure exclusive access to the slice file. - try: - slice_i = await get_slices(file_i, model.device) - for name, param in model.named_parameters(): - if name not in indices_dict or name not in slice_i: - continue - values = slice_i[name].to(param.data.device) - param_indices = indices_dict[name].to(param.data.device) - param_sums[name].view(-1)[param_indices] += values - slices_per_param[name] += 1 - del values - del slice_i - except Timeout: - # The lock could not be acquired within the timeout. - logger.error(f"Timeout occurred while trying to acquire lock on {file_i}") - continue - except Exception as e: - logger.exception(f"Error applying slice from {file_i}: {e}") - - # Apply the average to the parameters. - for name, param in model.named_parameters(): - if name not in slices_per_param or name not in indices_dict or slices_per_param[name] == 0: - continue - param_indices = indices_dict[name].to(param.data.device) - avg_param = param_sums[name].view(-1)[param_indices] / slices_per_param[name] - avg_param = avg_param.to(param.data.dtype) - avg_param = avg_param.to(param.data.device) - param.data.view(-1)[param_indices] = avg_param.clone() - - # Return the list of the files applied. - return slice_files - -async def upload_slice_for_window(bucket: str, model: torch.nn.Module, window: int, seed: str, wallet: 'bt.wallet', compression: int, key:str = 'slice'): - """ - Uploads a compressed slice of a PyTorch model to an S3 bucket. - - Args: - bucket (str): Name of the S3 bucket. - model (torch.nn.Module): The PyTorch model to be sliceed and uploaded. - window (int): The window identifier. - wallet (bt.wallet): The wallet object containing the hotkey. - compression (int): The compression factor. - """ - filename = f'{key}-{window}-{wallet.hotkey.ss58_address}.pt' - logger.debug(f"Uploading slice to S3: {filename}") - - model_state_dict = model.state_dict() - indices = await get_indices_for_window(model, seed, compression) - - # Apply the slice to the model parameters - for name, param in model.named_parameters(): - model_state_dict[name] = param.data.view(-1)[indices[name].to(model.device)].cpu() - - # Create a temporary file and write the sliceed model state dictionary to it - with tempfile.NamedTemporaryFile(delete=False) as temp_file: - torch.save(model_state_dict, temp_file) - temp_file_name = temp_file.name # Store the temporary file name - - # Upload the file to S3 - session = get_session() - async with session.create_client( - 's3', - region_name='us-east-1', - config=client_config, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY - ) as s3_client: - try: - with open(temp_file_name, 'rb') as f: - await s3_client.put_object(Bucket=bucket, Key=filename, Body=f) - # Set the object ACL to public-read - await s3_client.put_object_acl( - Bucket=bucket, - Key=filename, - ACL='public-read' - ) - logger.debug(f"Successfully uploaded slice to S3: {filename}") - except Exception: - logger.exception(f"Failed to upload slice {filename} to S3") - finally: - # Clean up the temporary file - os.remove(temp_file_name) - logger.debug(f"Temporary file {temp_file_name} removed") - -async def upload_master(bucket: str, model: torch.nn.Module, wallet: 'bt.wallet'): - """ - Uploads the master PyTorch model to an S3 bucket. 
- - Args: - bucket (str): Name of the S3 bucket. - model (torch.nn.Module): The PyTorch model to be uploaded. - wallet (bt.wallet): The wallet object containing the hotkey. - """ - upload_filename = f'master-{wallet.hotkey.ss58_address}.pt' - logger.debug(f"Uploading master model to S3: {upload_filename}") - - session = get_session() - async with session.create_client( - 's3', - region_name='us-east-1', - config=client_config, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY - ) as s3_client: - try: - # Create a temporary file and write the model state dictionary to it - with tempfile.NamedTemporaryFile(delete=False) as temp_file: - torch.save(model.state_dict(), temp_file) - temp_file_name = temp_file.name - - # Upload the file to S3 - with open(temp_file_name, 'rb') as f: - await s3_client.put_object(Bucket=bucket, Key=upload_filename, Body=f) - # Set the object ACL to public-read - await s3_client.put_object_acl( - Bucket=bucket, - Key=upload_filename, - ACL='public-read' - ) - logger.debug(f"Successfully uploaded master model to S3: {upload_filename}") - except Exception: - logger.exception(f"Failed to upload master model {upload_filename} to S3") - finally: - # Clean up the temporary file - os.remove(temp_file_name) - logger.debug(f"Temporary file {temp_file_name} removed") - -async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: int) -> Dict[str, torch.LongTensor]: - """ - Computes the indices for the given window and compression factor. - - Args: - model (torch.nn.Module): The PyTorch model. - seed (str): The window seed identifier. - compression (int): The compression factor. - - Returns: - Dict[str, torch.LongTensor]: A dictionary mapping parameter names to index tensors. - """ - logger.debug(f"Computing indices for window seed {seed} with compression {compression}") - result = {} - # Seed the random number generator with the seed - seed = int(hashlib.md5(str(seed).encode('utf-8')).hexdigest(), 16) % (2**32) - rng = np.random.default_rng(seed) - for name, param in model.named_parameters(): - # Randomly select indices based on the compression factor - num_indices = max(1, int(param.numel() // compression)) - indices = rng.choice(param.numel(), size=num_indices, replace=False) - result[name] = torch.from_numpy(indices).long().cpu() - return result - -async def download_file(s3_client, bucket: str, filename: str) -> str: - """ - Downloads a file from S3, using parallel downloads for large files. - - Args: - s3_client: The S3 client. - bucket (str): Name of the S3 bucket. - filename (str): The S3 object key (filename). - - Returns: - str: The path to the downloaded file in the temporary directory. - """ - async with semaphore: - temp_file = os.path.join(tempfile.gettempdir(), filename) - # Check if the file exists. 
- if os.path.exists(temp_file): - logger.debug(f"File {temp_file} already exists, skipping download.") - return temp_file - lock_file = f"{temp_file}.lock" - lock = FileLock(lock_file) - try: - # Try to acquire both locks with a timeout - with lock.acquire(timeout=1): - # Proceed to download the file - logger.debug(f"Downloading file {filename} to {temp_file}") - head_response = await s3_client.head_object(Bucket=bucket, Key=filename) - object_size = head_response['ContentLength'] - CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB - - response = await s3_client.get_object(Bucket=bucket, Key=filename) - async with aiofiles.open(temp_file, 'wb') as outfile: - while True: - chunk = await response['Body'].read(CHUNK_SIZE) - if not chunk: - break - await outfile.write(chunk) - - logger.debug(f"Successfully downloaded file {filename} to {temp_file}") - return temp_file - - except Timeout: - logger.error(f"Timeout occurred while trying to acquire lock on {lock_file}") - return None - except Exception as e: - logger.exception(f"Failed to download file {filename} from bucket {bucket}: {e}") - return None - finally: - # The lock is automatically released when exiting the 'with' block - pass - -async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window: int): - """ - Handles downloading a single file from S3. - - Args: - s3_client: The S3 client. - bucket (str): Name of the S3 bucket. - filename (str): The S3 object key (filename). - hotkey (str): The hotkey identifier. - window (int): The window identifier. - - Returns: - SimpleNamespace: An object containing file metadata and the path to the downloaded file. - """ - logger.debug(f"Handling file {filename} for window {window} and hotkey {hotkey}") - temp_file = await download_file(s3_client, bucket, filename) - if temp_file: - return SimpleNamespace(bucket=bucket, hotkey=hotkey, filename=filename, window=window, temp_file=temp_file) - return None - -async def process_bucket(s3_client, bucket: str, windows: List[int], key:str = 'slice'): - """ - Processes an S3 bucket to download files matching the given windows. - - Args: - s3_client: The S3 client. - bucket (str): Name of the S3 bucket. - windows (List[int]): A list of window identifiers. - - Returns: - List[SimpleNamespace]: A list of file metadata and paths for downloaded files. 
- """ - logger.debug(f"Processing bucket {bucket} for window {windows}") - files = [] - paginator = s3_client.get_paginator('list_objects_v2') - - for window in windows: - prefix = f'{key}-{window}' - logger.debug(f"Listing objects with prefix {prefix}") - async for page in paginator.paginate(Bucket=bucket, Prefix=prefix): - logger.trace(f"Processing page for prefix {prefix}") - if 'Contents' not in page: - logger.trace(f"No contents found for prefix {prefix}") - continue - download_tasks = [] - for obj in page.get('Contents', []): - filename = obj['Key'] - logger.trace(f"Processing object with key {filename}") - try: - parts = filename.split('-') - slice_window = int(parts[1]) - slice_hotkey = parts[2].split('.')[0] - logger.trace(f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}") - if slice_window == window: - download_tasks.append(handle_file(s3_client, bucket, filename, slice_hotkey, slice_window)) - except Exception: - logger.exception(f"Error processing filename {filename}") - continue - # Download the files concurrently - results = await asyncio.gather(*download_tasks) - files.extend([res for res in results if res]) - logger.trace(f"Completed processing page for prefix {prefix}") - logger.trace(f"Completed processing bucket {bucket} for windows {windows}") - return files - -async def download_slices_for_buckets_and_windows(buckets: List[str], windows: List[int], key:str = 'slice') -> Dict[int, List[SimpleNamespace]]: - """ - Downloads files from multiple S3 buckets for the given windows. - - Args: - buckets (List[str]): A list of S3 bucket names. - windows (List[int]): A list of window identifiers. - - Returns: - Dict[int, List[SimpleNamespace]]: A dictionary mapping windows to lists of file metadata and paths. - """ - logger.debug(f"Downloading files for buckets {set(buckets)} and windows {windows}") - session = get_session() - async with session.create_client( - 's3', - region_name='us-east-1', - config=client_config, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY - ) as s3_client: - tasks = [] - for bucket in set(buckets): - if not bucket: - continue - tasks.append(process_bucket(s3_client, bucket, windows, key)) - results = await asyncio.gather(*tasks) - # Flatten the list of lists - files = [item for sublist in results for item in sublist] - - # Create a dictionary with windows as keys and list of files as values - windows_dict = {} - for file in files: - window = file.window - if window not in windows_dict: - windows_dict[window] = [] - windows_dict[window].append(file) - - logger.debug(f"Downloaded all files grouped by windows: {windows}") - return windows_dict - -async def load_files_for_window(window: int, key: str = 'slice') -> List[str]: - """ - Retrieves the paths to downloaded window files from the temporary directory. - - Args: - window (int): The window identifier. - - Returns: - List[str]: A list of file paths corresponding to the window. 
- """ - logger.debug(f"Retrieving files for window {window} from temporary directory") - temp_dir = tempfile.gettempdir() - window_files = [] - for filename in os.listdir(temp_dir): - if filename.startswith(f"{key}-{window}-") and filename.endswith(".pt"): - window_files.append(os.path.join(temp_dir, filename)) - logger.debug(f"Found file {filename} for window {window}") - return window_files - -async def delete_files_before_window(window_max: int, key:str = 'slice'): - """ - Deletes all files on the local machine which have a window id before a specific value window_max. - - Args: - window_max (int): The maximum window id. Files with window ids less than this value will be deleted. - """ - logger.debug(f"Deleting files with window id before {window_max}") - temp_dir = tempfile.gettempdir() - for filename in os.listdir(temp_dir): - if filename.startswith(f"{key}-") and ( filename.endswith(".pt") or filename.endswith(".lock") ): - try: - parts = filename.split('-') - window_id = int(parts[1]) - if window_id < window_max: - if os.path.exists(filename): - os.remove(filename) - logger.debug(f"Deleted file {filename}") - except Exception as e: - logger.error(f"Error deleting file {filename}: {e}") - -async def delete_files_from_bucket_before_window(bucket: str, window_max: int, key: str = 'slice'): - """ - Deletes all files in the specified S3 bucket which have a window id before a specific value window_max. - - Args: - bucket (str): The name of the S3 bucket. - window_max (int): The maximum window id. Files with window ids less than this value will be deleted. - """ - logger.debug(f"Deleting files in bucket {bucket} with window id before {window_max}") - session = get_session() - async with session.create_client( - 's3', - region_name='us-east-1', - config=client_config, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY - ) as s3_client: - try: - response = await s3_client.list_objects_v2(Bucket=bucket) - if 'Contents' in response: - for obj in response['Contents']: - filename = obj['Key'] - if filename.startswith(f"{key}-") and filename.endswith(".pt"): - try: - parts = filename.split('-') - window_id = int(parts[1]) - if window_id < window_max: - await s3_client.delete_object(Bucket=bucket, Key=filename) - logger.debug(f"Deleted file {filename} from bucket {bucket}") - except Exception as e: - logger.error(f"Error deleting file {filename} from bucket {bucket}: {e}") - except Exception as e: - logger.error(f"Error listing objects in bucket {bucket}: {e}") diff --git a/docker_run.sh b/docker_run.sh deleted file mode 100644 index c1023e9..0000000 --- a/docker_run.sh +++ /dev/null @@ -1,387 +0,0 @@ -#!/usr/bin/env bash - -# The MIT License (MIT) -# © 2024 Chakana.tech - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -set -euo pipefail - -trap 'abort "An unexpected error occurred."' ERR - -# Set up colors and styles -if [[ -t 1 ]]; then - tty_escape() { printf "\033[%sm" "$1"; } -else - tty_escape() { :; } -fi -tty_mkbold() { tty_escape "1;$1"; } -tty_blue="$(tty_mkbold 34)" -tty_red="$(tty_mkbold 31)" -tty_green="$(tty_mkbold 32)" -tty_yellow="$(tty_mkbold 33)" -tty_bold="$(tty_mkbold 39)" -tty_reset="$(tty_escape 0)" - -ohai() { - printf "${tty_blue}==>${tty_bold} %s${tty_reset}\n" "$*" -} - -pdone() { - printf "${tty_green}[✔]${tty_bold} %s${tty_reset}\n" "$*" -} - -info() { - printf "${tty_green}%s${tty_reset}\n" "$*" -} - -warn() { - printf "${tty_yellow}Warning${tty_reset}: %s\n" "$*" >&2 -} - -error() { - printf "${tty_red}Error${tty_reset}: %s\n" "$*" >&2 -} - -abort() { - error "$@" - exit 1 -} - -getc() { - local save_state - save_state="$(/bin/stty -g)" - /bin/stty raw -echo - IFS='' read -r -n 1 -d '' "$@" - /bin/stty "${save_state}" -} - -wait_for_user() { - local c - echo - echo "Press ${tty_bold}RETURN${tty_reset}/${tty_bold}ENTER${tty_reset} to continue or any other key to abort:" - getc c - # we test for \r and \n because some stuff does \r instead - if ! [[ "${c}" == $'\r' || "${c}" == $'\n' ]] - then - exit 1 - fi -} - -execute() { - ohai "Running: $*" - if ! "$@"; then - abort "Failed during: $*" - fi -} - -have_root_access() { - if [ "$EUID" -ne 0 ]; then - warn "This script requires root privileges to install packages." - return 1 - fi - return 0 -} - -clear -echo "" -echo "" -echo " ______ _____ _______ ______ _______ _______ __ _ __ _" -echo " |_____] | | | | ____/ | | | |_____| | \ | | \ |" -echo " |_____] |_____| |_____ | /_____ | | | | | | \_| | \_|" -echo " " -echo "" -echo "" - -wait_for_user - -# Install Git if not present -if ! command -v git &> /dev/null; then - ohai "Installing git ..." - if have_root_access; then - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - if [ -f /etc/os-release ]; then - . /etc/os-release - if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then - ohai "Detected Ubuntu, installing Git..." - execute apt-get update -y - execute apt-get install -y git - else - warn "Unsupported Linux distribution: $ID" - abort "Cannot install Git automatically" - fi - else - warn "Cannot detect Linux distribution" - abort "Cannot install Git automatically" - fi - elif [[ "$OSTYPE" == "darwin"* ]]; then - ohai "Detected macOS, installing Git..." - execute xcode-select --install - else - abort "Unsupported OS type: $OSTYPE" - fi - else - abort "Root access is required to install Git." - fi -else - pdone "Found Git" -fi - -# Check if we are inside the cont repository -if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then - REPO_PATH="." -else - if [ ! -d "cont" ]; then - ohai "Cloning boltzmann ..." - execute git clone https://github.com/unconst/cont - REPO_PATH="cont/" - else - REPO_PATH="cont/" - fi -fi -pdone "Pulled Boltzmann $REPO_PATH" - -# Install Node.js and npm if not present -if ! command -v npm &> /dev/null; then - ohai "Installing Node.js and npm ..." 
- if have_root_access; then - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - execute apt-get update -y - execute apt-get install -y curl - curl -fsSL https://deb.nodesource.com/setup_20.x | bash - - execute apt-get install -y nodejs - elif [[ "$OSTYPE" == "darwin"* ]]; then - ohai "Detected macOS, installing Node.js and npm..." - execute brew install node - else - abort "Unsupported OS type: $OSTYPE" - fi - else - abort "Root access is required to install Node.js and npm." - fi - pdone "Installed Node.js and npm" -else - pdone "Found npm" -fi - -# Install pm2 -if ! command -v pm2 &> /dev/null; then - ohai "Installing pm2 ..." - execute npm install pm2 -g - pdone "Installed pm2" -else - pdone "Found pm2" -fi - -# Install Python 3.12 if not installed -if ! command -v python3.12 &> /dev/null; then - ohai "Installing python3.12 ..." - if have_root_access; then - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - if [ -f /etc/os-release ]; then - . /etc/os-release - if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then - ohai "Detected Ubuntu, installing Python 3.12..." - execute apt-get update -y - execute apt-get install -y software-properties-common gnupg - - # Add the deadsnakes PPA manually - ohai "Adding deadsnakes PPA manually..." - execute mkdir -p /etc/apt/keyrings - execute curl -fsSL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x6A755776" | gpg --dearmor --batch --yes -o /etc/apt/keyrings/deadsnakes-archive-keyring.gpg - echo "deb [signed-by=/etc/apt/keyrings/deadsnakes-archive-keyring.gpg] http://ppa.launchpad.net/deadsnakes/ppa/ubuntu jammy main" > /etc/apt/sources.list.d/deadsnakes-ppa.list - - execute apt-get update -y - execute apt-get install -y python3.12 python3.12-venv - - else - warn "Unsupported Linux distribution: $ID" - abort "Cannot install Python 3.12 automatically" - fi - else - warn "Cannot detect Linux distribution" - abort "Cannot install Python 3.12 automatically" - fi - elif [[ "$OSTYPE" == "darwin"* ]]; then - ohai "Detected macOS, installing Python 3.12..." - execute brew install python@3.12 - else - abort "Unsupported OS type: $OSTYPE" - fi - else - abort "Root access is required to install Python 3.12." - fi - pdone "Installed python3.12" -else - pdone "Found python3.12" -fi - -touch ~/.bash_profile - -# Prompt the user for AWS credentials and inject them into the bash_profile file if not already stored -if ! grep -q "AWS_ACCESS_KEY_ID" ~/.bash_profile || ! grep -q "AWS_SECRET_ACCESS_KEY" ~/.bash_profile || ! grep -q "BUCKET" ~/.bash_profile; then - clear - warn "This script will store your AWS credentials in your ~/.bash_profile file." - warn "This is not secure and is not recommended." - read -p "Do you want to proceed? [y/N]: " proceed - if [[ "$proceed" != "y" && "$proceed" != "Y" ]]; then - abort "Aborted by user." - fi - - read -p "Enter your AWS Access Key ID: " AWS_ACCESS_KEY_ID - read -p "Enter your AWS Secret Access Key: " AWS_SECRET_ACCESS_KEY - read -p "Enter your S3 Bucket Name: " BUCKET - - echo "export AWS_ACCESS_KEY_ID=\"$AWS_ACCESS_KEY_ID\"" >> ~/.bash_profile - echo "export AWS_SECRET_ACCESS_KEY=\"$AWS_SECRET_ACCESS_KEY\"" >> ~/.bash_profile - echo "export BUCKET=\"$BUCKET\"" >> ~/.bash_profile -fi - -# Source the bash_profile file to apply the changes -source ~/.bash_profile - -pdone "Found AWS credentials" - -# Create a virtual environment if it does not exist -if [ ! -d "$REPO_PATH/venv" ]; then - ohai "Creating virtual environment at $REPO_PATH..." 
- execute python3.12 -m venv "$REPO_PATH/venv" -fi -pdone "Created venv at $REPO_PATH" - -if [[ -z "${VIRTUAL_ENV:-}" ]]; then - ohai "Activating virtual environment..." - source "$REPO_PATH/venv/bin/activate" -fi -pdone "Activated venv at $REPO_PATH" - -ohai "Installing requirements..." -execute pip install --upgrade pip -execute pip install -r "$REPO_PATH/requirements.txt" -pdone "Installed requirements" - -# Check for GPUs -ohai "Checking for GPUs..." -if ! command -v nvidia-smi &> /dev/null; then - warn "nvidia-smi command not found. Please ensure NVIDIA drivers are installed." - NUM_GPUS=0 -else - NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) - ohai "Number of GPUs: $NUM_GPUS" - - if [ "$NUM_GPUS" -gt 0 ]; then - ohai "GPU Memory Information:" - nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | while read -r memory; do - ohai "$((memory / 1024)) GB" - done - else - warn "No GPUs found on this machine." - fi -fi - -# Check system RAM -ohai "Checking system RAM..." -if command -v free &> /dev/null; then - TOTAL_RAM=$(free -g | awk '/^Mem:/{print $2}') - ohai "Total System RAM: $TOTAL_RAM GB" -else - warn "Cannot determine system RAM. 'free' command not found." -fi - -# Create the default key -ohai "Creating the coldkey" -if ! python3.12 -c "import bittensor as bt; w = bt.wallet(); print(w.coldkey_file.exists_on_device())" | grep -q "True"; then - execute btcli w new_coldkey --wallet.path ~/.bittensor/wallets --wallet.name default --n-words 12 --no_password -else - ohai "Default key already exists on device." -fi - -# Ensure btcli is installed -if ! command -v btcli &> /dev/null; then - abort "btcli command not found. Please ensure it is installed." -fi - -# Create hotkeys and register them -if [ "$NUM_GPUS" -gt 0 ]; then - for i in $(seq 0 $((NUM_GPUS - 1))); do - # Check if the hotkey file exists on the device - exists_on_device=$(python3.12 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); print(w.hotkey_file.exists_on_device())" 2>/dev/null) - if [ "$exists_on_device" != "True" ]; then - echo "n" | btcli wallet new_hotkey --wallet.name default --wallet.hotkey C$i --n-words 12 > /dev/null 2>&1; - else - ohai "Hotkey C$i already exists on device." - fi - - # Check if the hotkey is registered on subnet 220 - is_registered=$(python3.12 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); sub = bt.subtensor('test'); print(sub.is_hotkey_registered_on_subnet(hotkey_ss58=w.hotkey.ss58_address, netuid=220))" 2>/dev/null) - if [[ "$is_registered" != *"True"* ]]; then - ohai "Registering key on subnet 220" - btcli subnet pow_register --wallet.name default --wallet.hotkey C$i --netuid 220 --subtensor.network test --no_prompt > /dev/null 2>&1; - else - ohai "Key is already registered on subnet 220" - fi - done -else - warn "No GPUs found. Skipping hotkey creation." -fi - -ohai "Logging into wandb..." -execute wandb login - -# Delete items from bucket -PROJECT=${2:-aesop} -ohai "Cleaning bucket $BUCKET..." 
-execute python3.12 "$REPO_PATH/tools/clean.py" --bucket "$BUCKET" - -# Start all the processes again -if [ "$NUM_GPUS" -gt 0 ]; then - for i in $(seq 0 $((NUM_GPUS - 1))); do - # Adjust GPU index for zero-based numbering - GPU_INDEX=$i - GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | sed -n "$((i + 1))p") - if [ -z "$GPU_MEMORY" ]; then - warn "Could not get GPU memory for GPU $i" - continue - fi - # Determine batch size based on GPU memory - # This section adjusts the batch size for the miner based on the available GPU memory - # Higher memory allows for larger batch sizes, which can improve performance - if [ "$GPU_MEMORY" -ge 80000 ]; then - # For GPUs with 80GB or more memory, use a batch size of 6 - BATCH_SIZE=6 - elif [ "$GPU_MEMORY" -ge 40000 ]; then - # For GPUs with 40GB to 79GB memory, use a batch size of 3 - BATCH_SIZE=3 - elif [ "$GPU_MEMORY" -ge 20000 ]; then - # For GPUs with 20GB to 39GB memory, use a batch size of 1 - BATCH_SIZE=1 - else - # For GPUs with less than 20GB memory, also use a batch size of 1 - # This ensures that even lower-end GPUs can still participate - BATCH_SIZE=1 - fi - ohai "Starting miner on GPU $GPU_INDEX with batch size $BATCH_SIZE..." - execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" - done -else - warn "No GPUs found. Skipping miner startup." -fi - -pm2 list - -ohai "Script completed successfully." diff --git a/hparams.json b/hparams.json index 35735d9..3b16da1 100644 --- a/hparams.json +++ b/hparams.json @@ -3,11 +3,11 @@ "compression": 100, "sequence_length": 2048, "tokenizer_name": "togethercomputer/LLaMA-2-7B-32K", - "num_hidden_layers": 16, - "hidden_size": 2048, + "num_hidden_layers": 40, + "hidden_size": 5120, "intermediate_size": 8192, - "num_attention_heads": 8, - "num_key_value_heads": 8, + "num_attention_heads": 40, + "num_key_value_heads": 40, "activation_function": "swiGLU", "max_position_embeddings": 2048, "mixed_precision_param": "bfloat16", diff --git a/miner.py b/miner.py index df79ae7..6bac906 100644 --- a/miner.py +++ b/miner.py @@ -27,19 +27,20 @@ import asyncio import argparse import threading -import traceback from tqdm import tqdm import bittensor as bt -from typing import List import torch.optim as optim from dotenv import dotenv_values from transformers import LlamaForCausalLM from torch.optim.lr_scheduler import CosineAnnealingLR +import torch.distributed as dist +from transformers.models.llama.modeling_llama import LlamaDecoderLayer # Import local files. -from common import * -from hparams import load_hparams -from dataset import DatasetLoader +from boltz.common import * +from boltz.hparams import load_hparams +from boltz.dataset import DatasetLoader +from boltz.fsdp import * # GPU optimizations. torch.backends.cudnn.benchmark = True @@ -106,14 +107,30 @@ def __init__(self): except: pass wandb.init(project=self.config.project, resume='allow', name=f'M{self.uid}', config=self.config) + # Initialize distributed training + self.local_rank, self.global_rank, self.world_size, self.device = initialize_distributed_training( + device_config=self.config.device) + # Init model. 
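+        # Note on the initialization order below: to_empty() moves the module
+        # onto uninitialized CUDA storage (the CPU-constructed weights are
+        # discarded), and init_weights() then re-initializes the parameters in
+        # place on the device.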
logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) self.hparams = load_hparams() torch.manual_seed(42); np.random.seed(42); random.seed(42) self.model = LlamaForCausalLM(config=self.hparams.model_config) - # self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') - self.model.to(self.config.device) + + # Wrap the model with Fully Sharded Data Parallel (FSDP) if distributed training is initialized + self.model = wrap_model_with_fsdp( + model=self.model, + device_id=self.local_rank, + transformer_layer_names=[LlamaDecoderLayer], + is_distributed=dist.is_initialized(), + device=self.device + ) + init_device = "cuda" + self.model.to_empty(device=init_device) + self.model.init_weights() self.model.train() + + self.model_parts = [self.model] self.optimizer = optim.AdamW( self.model.parameters(), lr=self.hparams.learning_rate, # Peak learning rate @@ -125,7 +142,8 @@ def __init__(self): self.optimizer, T_max=self.hparams.cosine_epoch_length, eta_min=self.hparams.eta_min, last_epoch=-1 ) - + # Synchronize all processes + dist.barrier() # Init buckets. self.buckets = [] for uid in self.metagraph.uids: @@ -166,6 +184,7 @@ async def run(self): self.listener = threading.Thread(target=self.block_listener, args=(self.loop,), daemon=True).start() # Optionally sync the model state by pulling model states from the history. + dist.barrier() if self.config.sync_state: history_windows = [ self.current_window - i for i in range (self.hparams.max_history) ] state_slices = await download_slices_for_buckets_and_windows( @@ -193,6 +212,7 @@ async def run(self): window = self.current_window # Run for non-baseline miners. + dist.barrier() if not self.config.baseline: st = T() state_slices = await download_slices_for_buckets_and_windows( @@ -250,7 +270,7 @@ async def run(self): total_steps += 1 if random.random() < self.sample_rate and not exhuasted_window: full_steps += 1 - input_ids = torch.tensor(batch, dtype=torch.long).to(self.model.device) + input_ids = torch.tensor(batch, dtype=torch.long).to(self.device) labels = input_ids.clone() labels = torch.where(labels == self.hparams.tokenizer.pad_token_id, -100, labels) with torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16): # Enable autocasting @@ -273,7 +293,7 @@ async def run(self): logger.info(f"{P(window, train_duration)} \tLoss: [tan]{step_loss}[tan]") if exhuasted_window: self.sample_rate = max(0.0001, self.sample_rate * 0.95) else: self.sample_rate = min(1, self.sample_rate * 1.05) - + dist.barrier() # Run for non-baseline nodes. if not self.config.baseline: # Upload the delta for the previous window. @@ -338,11 +358,14 @@ async def run(self): f"learning_rate": self.scheduler.get_last_lr()[0] }) - # Catch keyboard interrrupt. + # Catch keyboard interrupt. except KeyboardInterrupt: logger.info("Training interrupted by user. Stopping the run.") self.stop_event.set() await self.update_task + if dist.is_initialized(): + dist.destroy_process_group() + logger.info("Destroyed process group.") sys.exit(0) # Catch unknown. 
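The sample-rate update in the hunk above is a small multiplicative-decrease / multiplicative-increase controller over the fraction of batches a miner trains on. A minimal sketch of the rule, extracted for clarity (the function name is illustrative, not part of the codebase):

```python
def adjust_sample_rate(sample_rate: float, exhausted_window: bool) -> float:
    """Per-window update of the fraction of batches a miner trains on."""
    if exhausted_window:
        # Back off by 5% when the window was exhausted, floored at 0.01%.
        return max(0.0001, sample_rate * 0.95)
    # Otherwise ramp back up by 5%, capped at sampling every batch.
    return min(1.0, sample_rate * 1.05)
```

The floor keeps a miner sampling at least 0.01% of batches, so it never drops out of a window entirely.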
@@ -381,4 +404,10 @@ def handler(event, _u, _s): time.sleep(1) if __name__ == "__main__": - asyncio.run(Miner().run()) + try: + asyncio.run(Miner().run()) + finally: + # Ensure process group is destroyed on exit + if dist.is_initialized(): + dist.destroy_process_group() + logger.info("Destroyed process group on exit.") diff --git a/run.sh b/scripts/run.sh similarity index 100% rename from run.sh rename to scripts/run.sh diff --git a/start.sh b/scripts/start.sh similarity index 100% rename from start.sh rename to scripts/start.sh diff --git a/scripts/start_distributed.sh b/scripts/start_distributed.sh new file mode 100644 index 0000000..46764e5 --- /dev/null +++ b/scripts/start_distributed.sh @@ -0,0 +1,43 @@ +# The MIT License (MIT) +# © 2024 Chakana.tech + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# Close down all previous processes and restart them. + +pm2 sendSignal SIGINT all +pm2 delete all +# Delete items from bucket +BUCKET=${1:-decis} +PROJECT=${2:-aesop} +# python3 tools/clean.py --bucket $BUCKET + +# Number of GPUs to use. +NGPU=${NGPU:-3} +# The master port for distributed training. +MASTER_PORT=${MASTER_PORT:-29500} +# Which rank should be logged +LOG_RANK=${LOG_RANK:-0} +# Uncomment for debugging. +# export TORCHELASTIC_ERROR_FILE=error.log +# export TORCHELASTIC_DEBUG=1 +# export PYTHONFAULTHANDLER=1 + +# Start the miner using torchrun with distributed training. +pm2 start "torchrun --nproc_per_node=${NGPU} \ + --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank \ + --tee 3 miner.py -- --actual_batch_size 6 --wallet.name Bistro \ +--wallet.hotkey M111 --bucket $BUCKET --device cuda:1 --use_wandb --project $PROJECT --debug" --name M1 --interpreter none diff --git a/tests/eval.py b/tests/eval.py index 3d4a533..e12191d 100644 --- a/tests/eval.py +++ b/tests/eval.py @@ -28,7 +28,7 @@ import tempfile import traceback import bittensor as bt -from hparams import load_hparams +from boltz.hparams import load_hparams from types import SimpleNamespace from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM diff --git a/tests/legacy_miner.py b/tests/legacy_miner.py index d9103aa..3e33a6a 100644 --- a/tests/legacy_miner.py +++ b/tests/legacy_miner.py @@ -37,9 +37,9 @@ from torch.optim.lr_scheduler import CosineAnnealingLR # Import local files. 
-from common import * -from hparams import load_hparams -from dataset import DatasetLoader +from boltz.common import * +from boltz.hparams import load_hparams +from boltz.dataset import DatasetLoader # GPU optimizations. torch.backends.cudnn.benchmark = True diff --git a/tests/legacy_validator.py b/tests/legacy_validator.py index 6a753f4..8aecec1 100644 --- a/tests/legacy_validator.py +++ b/tests/legacy_validator.py @@ -36,9 +36,9 @@ from torch.optim.lr_scheduler import CosineAnnealingLR # Import local files. -from common import * -from hparams import load_hparams -from dataset import DatasetLoader +from boltz.common import * +from boltz.hparams import load_hparams +from boltz.dataset import DatasetLoader # GPU optimizations. torch.backends.cudnn.benchmark = True diff --git a/tests/slices.py b/tests/slices.py new file mode 100644 index 0000000..e2bb6b8 --- /dev/null +++ b/tests/slices.py @@ -0,0 +1,489 @@ +import os +import socket +import sys +from typing import Dict, List +from filelock import FileLock +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.multiprocessing as mp +import asyncio +from torch.distributed.fsdp import StateDictType, FullyShardedDataParallel as FSDP +import logging +from torch.distributed import DeviceMesh +import hashlib +import asyncio +import logging +import tempfile +import aiofiles +from dotenv import dotenv_values +import os +import botocore.config +from aiobotocore.session import get_session +import numpy as np +from torch.distributed._shard.sharded_tensor import ShardedTensor + +env_config = {**dotenv_values(".env"), **os.environ} +AWS_ACCESS_KEY_ID = env_config.get("AWS_ACCESS_KEY_ID") +AWS_SECRET_ACCESS_KEY = env_config.get("AWS_SECRET_ACCESS_KEY") +# Configure the S3 client +client_config = botocore.config.Config( + max_pool_connections=256, +) + +class SimpleTransformer(nn.Module): + def __init__(self, d_model=512, nhead=8, num_layers=6): + super().__init__() + # Create layers as a ModuleDict with numeric keys + self.layers = nn.ModuleDict( + { + str(i): nn.TransformerEncoderLayer(d_model, nhead, batch_first=True) + for i in range(num_layers) + } + ) + self.fc = nn.Linear(d_model, d_model) + + def forward(self, x): + # Apply transformer layers sequentially + for layer_idx in sorted(self.layers.keys(), key=int): # Sort numerically + x = self.layers[layer_idx](x) + return self.fc(x) + +async def upload_slice_for_window( + bucket: str, + model: nn.Module, + window: int, + wallet, + seed: str, + compression: int, + logger: logging.Logger, +) -> str: + """ + Generates and saves a sliced version of the model's parameters to a local file. + + Args: + bucket (str): The name of the S3 bucket (not used for local testing). + model (nn.Module): The FSDP-wrapped model. + window (int): The current window number. + wallet: The wallet object containing the hotkey address. + seed (str): The seed for index generation. + compression (int): The compression factor. + logger (logging.Logger): Logger for debug output. + + Returns: + str: The filename of the saved slice. 
+ """ + rank = dist.get_rank() + logger.debug(f"Rank {rank}: Saving slice to local file") + + device = torch.device("cuda", rank) + + # Include the rank in the filename + filename: str = f"slice-{window}-rank{rank}-{wallet.hotkey.ss58_address}.pt" + logger.debug(f"Rank {rank}: Filename for slice: {filename}") + + indices: Dict[str, torch.LongTensor] = await get_indices_for_window( + model=model, seed=seed, compression=compression, logger=logger + ) + + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + model_state_dict = model.state_dict() + + sliced_state_dict: Dict[str, torch.Tensor] = {} + for name in model_state_dict.keys(): + if name in indices: + idx = indices[name].to(device) + + # Get local shard of the parameter + local_shards = model_state_dict[name].local_shards() + if not local_shards: + logger.warning( + f"Rank {rank}: No local shards for parameter '{name}'. Skipping." + ) + continue + + local_shard = local_shards[0].tensor.to(device) + local_shard_flat = local_shard.contiguous().view(-1) + sliced_param = local_shard_flat[idx] + sliced_state_dict[name] = sliced_param.cpu() + logger.debug( + f"Rank {rank}: Sliced parameter '{name}' with shape {sliced_param.shape}" + ) + else: + logger.debug(f"Rank {rank}: Parameter '{name}' not in indices. Skipping.") + + # Save the sliced state dict to a local file + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + torch.save(sliced_state_dict, file_path) + logger.debug(f"Rank {rank}: Saved sliced state dict to {file_path}") + + return file_path + + +async def get_indices_for_window( + model: torch.nn.Module, seed: str, compression: int, logger: logging.Logger +) -> Dict[str, torch.LongTensor]: + """ + Computes the indices for the given window and compression factor, + ensuring that the compression is applied correctly across all ranks. + + Args: + model (torch.nn.Module): The FSDP-wrapped PyTorch model. + seed (str): The window seed identifier. + compression (int): The compression factor. + logger (logging.Logger): Logger for debug statements. + + Returns: + Dict[str, torch.LongTensor]: A mapping from parameter names to local indices tensors. + """ + rank = dist.get_rank() + logger.debug( + f"Rank {rank}: Computing indices for window seed '{seed}' with compression '{compression}'" + ) + + result = {} + + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + state_dict = model.state_dict() + + for name, param in state_dict.items(): + logger.debug(f"Rank {rank}: Processing parameter: {name}") + + # Check if parameter is a ShardedTensor + if not isinstance(param, ShardedTensor): + logger.warning(f"Parameter '{name}' is not a ShardedTensor. Skipping.") + continue + + # Get total size of the parameter from metadata + param_metadata = param.metadata() # Call the metadata function + global_size = param_metadata.size # Global size of the parameter + + # Compute the total number of elements + total_param_size = 1 + for dim_size in global_size: + total_param_size *= dim_size + + if total_param_size <= 0: + logger.warning(f"Parameter '{name}' has no elements. 
Skipping.") + continue + + # Compute the total number of indices to select + num_indices = max(1, total_param_size // compression) + num_indices = min(num_indices, total_param_size) + + # Generate the same global indices on all ranks + seed_str = f"{seed}_{name}" + seed_int = int(hashlib.md5(seed_str.encode("utf-8")).hexdigest(), 16) % (2**32) + rng = np.random.default_rng(seed_int) + + global_indices = rng.choice(total_param_size, size=num_indices, replace=False) + global_indices.sort() + + # Get local shard metadata + local_shards = param.local_shards() + if not local_shards: + logger.warning( + f"Rank {rank}: No local shards for parameter '{name}'. Skipping." + ) + continue + + local_shard_metadata = local_shards[0].metadata + shard_offsets = local_shard_metadata.shard_offsets + shard_sizes = local_shard_metadata.shard_sizes + + # Assuming the parameter is flattened (1D), adjust the indices accordingly + shard_start_idx = shard_offsets[0] + shard_end_idx = shard_start_idx + shard_sizes[0] + + # Find indices that are within the local shard + mask = (global_indices >= shard_start_idx) & (global_indices < shard_end_idx) + local_indices = global_indices[mask] - shard_start_idx + + if local_indices.size == 0: + logger.debug( + f"Rank {rank}: No indices for parameter '{name}' in local shard." + ) + continue + + indices_tensor = torch.from_numpy(local_indices).long() + result[name] = indices_tensor + + logger.debug( + f"Rank {rank}: Generated {len(indices_tensor)} local indices for parameter '{name}'" + ) + + if not result: + logger.warning( + f"Rank {rank}: No indices generated. Slice will not be uploaded." + ) + return {} + + return result + + +async def apply_slices_to_model( + model: torch.nn.Module, + window: int, + seed: str, + compression: int, + logger: logging.Logger, + key: str = "slice", +) -> List[str]: + """ + Applies slices from a specific window to the given FSDP model. + + Args: + model (torch.nn.Module): The FSDP-wrapped model to which the slices will be applied. + window (int): The window identifier. + seed (str): The seed used for generating indices. + compression (int): The compression factor. + logger (logging.Logger): Logger for debug output. + key (str): Key prefix for slice files. + + Returns: + List[str]: A list of all the slice files that were applied. + """ + rank = dist.get_rank() + logger.debug( + f"Rank {rank}: Applying slices for window {window} with seed '{seed}' and compression '{compression}'" + ) + + # Get indices for this rank's parameters + indices_dict = await get_indices_for_window(model, seed, compression, logger) + + # Load slices specific to this rank + slice_files = await load_files_for_window_and_rank( + window=window, rank=rank, key=key, logger=logger + ) + + if not slice_files: + logger.warning( + f"Rank {rank}: No slice files found for window {window} and rank {rank}" + ) + return [] + + device = torch.device("cuda", rank) + logger.debug(f"Rank {rank}: Using device {device}") + + # Dictionaries to accumulate the sum of values and count per parameter + param_sums: Dict[str, torch.Tensor] = {} + slices_per_param: Dict[str, int] = {} + + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + model_state_dict = model.state_dict() + + for name in model_state_dict.keys(): + # Get local shard of the parameter + local_shards = model_state_dict[name].local_shards() + if not local_shards: + logger.warning( + f"Rank {rank}: No local shards for parameter '{name}'. Skipping." 
+ ) + continue + + local_shard = local_shards[0].tensor.to(device) + param_sums[name] = torch.zeros_like(local_shard) + slices_per_param[name] = 0 + + # Iterate over each slice file and compute the sum of values + for file_i in slice_files: + try: + slice_i = torch.load(file_i, map_location=device) + for name in param_sums.keys(): + if name not in indices_dict or name not in slice_i: + continue + values = slice_i[name].to(device) + param_indices = indices_dict[name].to(device) + param_sums[name].view(-1)[param_indices] += values + slices_per_param[name] += 1 + del slice_i + except Exception as e: + logger.exception(f"Rank {rank}: Error applying slice from {file_i}: {e}") + + # Apply the average to the parameters + for name in param_sums.keys(): + if slices_per_param[name] == 0: + continue + param_indices = indices_dict[name].to(device) + avg_param = param_sums[name].view(-1)[param_indices] / slices_per_param[name] + + # Update the local shard of the parameter within no_grad + with torch.no_grad(): + local_shards = model_state_dict[name].local_shards() + local_shard = local_shards[0].tensor.to(device) + local_shard_flat = local_shard.view(-1) + local_shard_flat.index_copy_( + 0, param_indices, avg_param.to(local_shard_flat.dtype) + ) + + # Load the updated state dict back into the model + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + model.load_state_dict(model_state_dict, strict=False) + logger.debug(f"Rank {rank}: Loaded updated state dict into model.") + + return slice_files + + +async def load_files_for_window_and_rank( + window: int, rank: int, logger, key: str = "slice" +) -> List[str]: + """ + Retrieves the paths to downloaded window files for a specific rank from the temporary directory. + + Args: + window (int): The window identifier. + rank (int): The rank identifier. + + Returns: + List[str]: A list of file paths corresponding to the window and rank. + """ + logger.debug( + f"Retrieving files for window {window} and rank {rank} from temporary directory" + ) + temp_dir = tempfile.gettempdir() + window_files = [] + file_pattern = f"{key}-{window}-rank{rank}-" + for filename in os.listdir(temp_dir): + if filename.startswith(file_pattern) and filename.endswith(".pt"): + file_path = os.path.join(temp_dir, filename) + window_files.append(file_path) + logger.debug(f"Found file {filename} for window {window} and rank {rank}") + return window_files +def setup(rank, world_size, master_port): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + logger = logging.getLogger(f"Rank {rank}") + logger.debug(f"Process group initialized. Using GPU {rank}") + + +def setup_logging(rank: int) -> logging.Logger: + """ + Sets up a logger for the given rank. + + Args: + rank (int): The rank of the current process. + + Returns: + logging.Logger: Configured logger for the rank. 
+ """ + logger = logging.getLogger(f"Rank {rank}") + logger.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter( + "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Create File Handler + file_handler = logging.FileHandler(f"log_rank_{rank}.log") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + + # Create Stream Handler (optional, can be removed if not needed) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(formatter) + + # Avoid adding multiple handlers to the logger + if not logger.handlers: + logger.addHandler(file_handler) + logger.addHandler(stream_handler) + + return logger + + +def cleanup(rank): + logger = logging.getLogger(f"Rank {rank}") + if dist.is_initialized(): + dist.destroy_process_group() + logger.debug("Process group destroyed.") + else: + logger.warning("Process group not initialized; skipping cleanup.") + +def run_fsdp(rank, world_size, master_port): + logger = logging.getLogger(f"Rank {rank}") + try: + logger = setup_logging(rank) + logger.info(f"Running on rank {rank}") + setup(rank, world_size, master_port) + + model = SimpleTransformer().to(rank) + # Count and log the number of parameters in the model + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.info(f"Total parameters in the model: {total_params:,}") + logger.info(f"Trainable parameters in the model: {trainable_params:,}") + + fsdp_model = FSDP(model) + seed = "test_seed" + compression = 10000 + window = 125 + + class MockWallet: + class hotkey: + ss58_address = f"test_address_{rank}" + + mock_wallet = MockWallet() + + # Generate and save slices + filename = asyncio.run( + upload_slice_for_window( + bucket="decis", # Not used in this context + model=fsdp_model, + window=window, + wallet=mock_wallet, + seed=seed, + compression=compression, + logger=logger, + ) + ) + + logger.debug(f"Rank {rank}: Slice saved to {filename}") + + # Apply slices to the model + slice_files = asyncio.run( + apply_slices_to_model( + model=fsdp_model, + window=window, + seed=seed, + compression=compression, + logger=logger, + ) + ) + + logger.debug(f"Rank {rank}: Applied slices: {slice_files}") + + except Exception as e: + logger.exception(f"An error occurred: {str(e)}") + finally: + cleanup(rank) + + +def get_available_gpus(): + return torch.cuda.device_count() + + +def find_free_port(): + """Finds a free port on the localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +if __name__ == "__main__": + world_size = 4 + available_gpus = get_available_gpus() + if available_gpus < world_size: + print(f"Insufficient GPUs. Available: {available_gpus}, Required: {world_size}") + sys.exit(1) + + # Dynamically find a free port to avoid EADDRINUSE + master_port = find_free_port() + print(f"Starting FSDP with world_size={world_size} on port {master_port}") + + mp.spawn(run_fsdp, args=(world_size, master_port), nprocs=world_size, join=True) diff --git a/validator.py b/validator.py index 29c5b12..e743fd0 100644 --- a/validator.py +++ b/validator.py @@ -34,11 +34,13 @@ from dotenv import dotenv_values from transformers import LlamaForCausalLM from torch.optim.lr_scheduler import CosineAnnealingLR +from transformers.models.llama.modeling_llama import LlamaDecoderLayer # Import local files. 
-from common import * -from hparams import load_hparams -from dataset import DatasetLoader +from boltz.common import * +from boltz.hparams import load_hparams +from boltz.dataset import DatasetLoader +from boltz.fsdp import * # GPU optimizations. torch.backends.cudnn.benchmark = True @@ -100,7 +102,10 @@ def __init__(self): if run.name == f'V{self.uid}': logger.info(f'Deleting old run: {run}'); run.delete() except: pass - wandb.init(project=self.config.project, resume='allow', name=f'V{self.uid}', config=self.config) + + # Initialize distributed training + self.local_rank, self.global_rank, self.world_size, self.device = initialize_distributed_training( + device_config=self.config.device) # Init model. logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) @@ -108,7 +113,18 @@ def __init__(self): torch.manual_seed(42); np.random.seed(42); random.seed(42) self.model = LlamaForCausalLM(config=self.hparams.model_config) # self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') - self.model.to(self.config.device) + + # Wrap the model with Fully Sharded Data Parallel (FSDP) if distributed training is initialized + self.model = wrap_model_with_fsdp( + model=self.model, + device_id=self.local_rank, + transformer_layer_names=[LlamaDecoderLayer], + is_distributed=dist.is_initialized(), + device=self.device + ) + init_device = "cuda" + self.model.to_empty(device=init_device) + self.model.init_weights() self.model.eval() # Init buckets.
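A note on why the slices in `tests/slices.py` line up across peers: the indices are a pure function of the window seed and the parameter name, so every rank derives the same global positions without any communication. A self-contained sketch of that selection scheme (the helper name `indices_for` is illustrative, not part of the codebase):

```python
import hashlib

import numpy as np


def indices_for(seed: str, name: str, numel: int, compression: int) -> np.ndarray:
    """Deterministically pick ~numel/compression flat indices into a tensor."""
    # Seed the RNG from (window seed, parameter name): identical on every rank.
    digest = hashlib.md5(f"{seed}_{name}".encode("utf-8")).hexdigest()
    rng = np.random.default_rng(int(digest, 16) % (2**32))
    num_indices = min(max(1, numel // compression), numel)
    indices = rng.choice(numel, size=num_indices, replace=False)
    indices.sort()
    return indices


# With compression=100, a 1M-element parameter yields the same 10,000 sorted
# indices on every peer that shares the seed.
assert len(indices_for("window-125", "layers.0.fc.weight", 1_000_000, 100)) == 10_000
```

Each FSDP rank then keeps only the subset of these global indices that falls inside its local shard, offset by the shard's start position, which is exactly what `get_indices_for_window` does above.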