From 9269cd014d1087176e3153aa5829b161fc740532 Mon Sep 17 00:00:00 2001
From: distributedstatemachine
Date: Fri, 25 Oct 2024 19:15:54 +0400
Subject: [PATCH 1/9] feat: slices test

---
 tests/slices.py | 490 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 490 insertions(+)
 create mode 100644 tests/slices.py

diff --git a/tests/slices.py b/tests/slices.py
new file mode 100644
index 0000000..d32ecf8
--- /dev/null
+++ b/tests/slices.py
@@ -0,0 +1,490 @@
+import os
+import socket
+import sys
+from typing import Dict, List
+from filelock import FileLock
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import asyncio
+from torch.distributed.fsdp import StateDictType, FullyShardedDataParallel as FSDP
+import logging
+from torch.distributed import DeviceMesh
+import hashlib
+import tempfile
+import aiofiles
+from dotenv import dotenv_values
+import botocore.config
+from aiobotocore.session import get_session
+import numpy as np
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+env_config = {**dotenv_values(".env"), **os.environ}
+AWS_ACCESS_KEY_ID = env_config.get("AWS_ACCESS_KEY_ID")
+AWS_SECRET_ACCESS_KEY = env_config.get("AWS_SECRET_ACCESS_KEY")
+# Configure the S3 client
+client_config = botocore.config.Config(
+    max_pool_connections=256,
+)
+
+class SimpleTransformer(nn.Module):
+    def __init__(self, d_model=512, nhead=8, num_layers=6):
+        super().__init__()
+        # Create layers as a ModuleDict with numeric keys
+        self.layers = nn.ModuleDict(
+            {
+                str(i): nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
+                for i in range(num_layers)
+            }
+        )
+        self.fc = nn.Linear(d_model, d_model)
+
+    def forward(self, x):
+        # Apply transformer layers sequentially
+        for layer_idx in sorted(self.layers.keys(), key=int):  # Sort numerically
+            x = self.layers[layer_idx](x)
+        return self.fc(x)
+
+
+def setup(rank, world_size, master_port):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(master_port)
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)
+    logger = logging.getLogger(f"Rank {rank}")
+    logger.debug(f"Process group initialized. Using GPU {rank}")
+
+
+def setup_logging(rank: int) -> logging.Logger:
+    """
+    Sets up a logger for the given rank.
+
+    Args:
+        rank (int): The rank of the current process.
+
+    Returns:
+        logging.Logger: Configured logger for the rank.
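+
+    Example (illustrative):
+        >>> logger = setup_logging(rank=0)
+        >>> logger.debug("worker ready")  # written to log_rank_0.log and to the console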
+ """ + logger = logging.getLogger(f"Rank {rank}") + logger.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter( + "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Create File Handler + file_handler = logging.FileHandler(f"log_rank_{rank}.log") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + + # Create Stream Handler (optional, can be removed if not needed) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(formatter) + + # Avoid adding multiple handlers to the logger + if not logger.handlers: + logger.addHandler(file_handler) + logger.addHandler(stream_handler) + + return logger + + +def cleanup(rank): + logger = logging.getLogger(f"Rank {rank}") + if dist.is_initialized(): + dist.destroy_process_group() + logger.debug("Process group destroyed.") + else: + logger.warning("Process group not initialized; skipping cleanup.") + + +async def upload_slice_for_window( + bucket: str, + model: nn.Module, + window: int, + wallet, + seed: str, + compression: int, + logger: logging.Logger, +) -> str: + """ + Generates and saves a sliced version of the model's parameters to a local file. + + Args: + bucket (str): The name of the S3 bucket (not used for local testing). + model (nn.Module): The FSDP-wrapped model. + window (int): The current window number. + wallet: The wallet object containing the hotkey address. + seed (str): The seed for index generation. + compression (int): The compression factor. + logger (logging.Logger): Logger for debug output. + + Returns: + str: The filename of the saved slice. + """ + rank = dist.get_rank() + logger.debug(f"Rank {rank}: Saving slice to local file") + + device = torch.device("cuda", rank) + + # Include the rank in the filename + filename: str = f"slice-{window}-rank{rank}-{wallet.hotkey.ss58_address}.pt" + logger.debug(f"Rank {rank}: Filename for slice: {filename}") + + indices: Dict[str, torch.LongTensor] = await get_indices_for_window( + model=model, seed=seed, compression=compression, logger=logger + ) + + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + model_state_dict = model.state_dict() + + sliced_state_dict: Dict[str, torch.Tensor] = {} + for name in model_state_dict.keys(): + if name in indices: + idx = indices[name].to(device) + + # Get local shard of the parameter + local_shards = model_state_dict[name].local_shards() + if not local_shards: + logger.warning( + f"Rank {rank}: No local shards for parameter '{name}'. Skipping." + ) + continue + + local_shard = local_shards[0].tensor.to(device) + local_shard_flat = local_shard.contiguous().view(-1) + sliced_param = local_shard_flat[idx] + sliced_state_dict[name] = sliced_param.cpu() + logger.debug( + f"Rank {rank}: Sliced parameter '{name}' with shape {sliced_param.shape}" + ) + else: + logger.debug(f"Rank {rank}: Parameter '{name}' not in indices. Skipping.") + + # Save the sliced state dict to a local file + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + torch.save(sliced_state_dict, file_path) + logger.debug(f"Rank {rank}: Saved sliced state dict to {file_path}") + + return file_path + + +async def get_indices_for_window( + model: torch.nn.Module, seed: str, compression: int, logger: logging.Logger +) -> Dict[str, torch.LongTensor]: + """ + Computes the indices for the given window and compression factor, + ensuring that the compression is applied correctly across all ranks. 
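+    All ranks derive the same per-parameter RNG seed from
+    md5(f"{seed}_{name}"), draw an identical set of global indices, and keep
+    only those indices that fall inside their local shard.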
+ + Args: + model (torch.nn.Module): The FSDP-wrapped PyTorch model. + seed (str): The window seed identifier. + compression (int): The compression factor. + logger (logging.Logger): Logger for debug statements. + + Returns: + Dict[str, torch.LongTensor]: A mapping from parameter names to local indices tensors. + """ + rank = dist.get_rank() + logger.debug( + f"Rank {rank}: Computing indices for window seed '{seed}' with compression '{compression}'" + ) + + result = {} + + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + state_dict = model.state_dict() + + for name, param in state_dict.items(): + logger.debug(f"Rank {rank}: Processing parameter: {name}") + + # Check if parameter is a ShardedTensor + if not isinstance(param, ShardedTensor): + logger.warning(f"Parameter '{name}' is not a ShardedTensor. Skipping.") + continue + + # Get total size of the parameter from metadata + param_metadata = param.metadata() # Call the metadata function + global_size = param_metadata.size # Global size of the parameter + + # Compute the total number of elements + total_param_size = 1 + for dim_size in global_size: + total_param_size *= dim_size + + if total_param_size <= 0: + logger.warning(f"Parameter '{name}' has no elements. Skipping.") + continue + + # Compute the total number of indices to select + num_indices = max(1, total_param_size // compression) + num_indices = min(num_indices, total_param_size) + + # Generate the same global indices on all ranks + seed_str = f"{seed}_{name}" + seed_int = int(hashlib.md5(seed_str.encode("utf-8")).hexdigest(), 16) % (2**32) + rng = np.random.default_rng(seed_int) + + global_indices = rng.choice(total_param_size, size=num_indices, replace=False) + global_indices.sort() + + # Get local shard metadata + local_shards = param.local_shards() + if not local_shards: + logger.warning( + f"Rank {rank}: No local shards for parameter '{name}'. Skipping." + ) + continue + + local_shard_metadata = local_shards[0].metadata + shard_offsets = local_shard_metadata.shard_offsets + shard_sizes = local_shard_metadata.shard_sizes + + # Assuming the parameter is flattened (1D), adjust the indices accordingly + shard_start_idx = shard_offsets[0] + shard_end_idx = shard_start_idx + shard_sizes[0] + + # Find indices that are within the local shard + mask = (global_indices >= shard_start_idx) & (global_indices < shard_end_idx) + local_indices = global_indices[mask] - shard_start_idx + + if local_indices.size == 0: + logger.debug( + f"Rank {rank}: No indices for parameter '{name}' in local shard." + ) + continue + + indices_tensor = torch.from_numpy(local_indices).long() + result[name] = indices_tensor + + logger.debug( + f"Rank {rank}: Generated {len(indices_tensor)} local indices for parameter '{name}'" + ) + + if not result: + logger.warning( + f"Rank {rank}: No indices generated. Slice will not be uploaded." + ) + return {} + + return result + + +async def apply_slices_to_model( + model: torch.nn.Module, + window: int, + seed: str, + compression: int, + logger: logging.Logger, + key: str = "slice", +) -> List[str]: + """ + Applies slices from a specific window to the given FSDP model. + + Args: + model (torch.nn.Module): The FSDP-wrapped model to which the slices will be applied. + window (int): The window identifier. + seed (str): The seed used for generating indices. + compression (int): The compression factor. + logger (logging.Logger): Logger for debug output. + key (str): Key prefix for slice files. 
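+
+    Note:
+        For every selected index, the values from all slice files are summed
+        and divided by the number of contributing slices, so the update
+        written back to the local shard is an unweighted mean.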
+ + Returns: + List[str]: A list of all the slice files that were applied. + """ + rank = dist.get_rank() + logger.debug( + f"Rank {rank}: Applying slices for window {window} with seed '{seed}' and compression '{compression}'" + ) + + # Get indices for this rank's parameters + indices_dict = await get_indices_for_window(model, seed, compression, logger) + + # Load slices specific to this rank + slice_files = await load_files_for_window_and_rank( + window=window, rank=rank, key=key, logger=logger + ) + + if not slice_files: + logger.warning( + f"Rank {rank}: No slice files found for window {window} and rank {rank}" + ) + return [] + + device = torch.device("cuda", rank) + logger.debug(f"Rank {rank}: Using device {device}") + + # Dictionaries to accumulate the sum of values and count per parameter + param_sums: Dict[str, torch.Tensor] = {} + slices_per_param: Dict[str, int] = {} + + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + model_state_dict = model.state_dict() + + for name in model_state_dict.keys(): + # Get local shard of the parameter + local_shards = model_state_dict[name].local_shards() + if not local_shards: + logger.warning( + f"Rank {rank}: No local shards for parameter '{name}'. Skipping." + ) + continue + + local_shard = local_shards[0].tensor.to(device) + param_sums[name] = torch.zeros_like(local_shard) + slices_per_param[name] = 0 + + # Iterate over each slice file and compute the sum of values + for file_i in slice_files: + try: + slice_i = torch.load(file_i, map_location=device) + for name in param_sums.keys(): + if name not in indices_dict or name not in slice_i: + continue + values = slice_i[name].to(device) + param_indices = indices_dict[name].to(device) + param_sums[name].view(-1)[param_indices] += values + slices_per_param[name] += 1 + del slice_i + except Exception as e: + logger.exception(f"Rank {rank}: Error applying slice from {file_i}: {e}") + + # Apply the average to the parameters + for name in param_sums.keys(): + if slices_per_param[name] == 0: + continue + param_indices = indices_dict[name].to(device) + avg_param = param_sums[name].view(-1)[param_indices] / slices_per_param[name] + + # Update the local shard of the parameter within no_grad + with torch.no_grad(): + local_shards = model_state_dict[name].local_shards() + local_shard = local_shards[0].tensor.to(device) + local_shard_flat = local_shard.view(-1) + local_shard_flat.index_copy_( + 0, param_indices, avg_param.to(local_shard_flat.dtype) + ) + + # Load the updated state dict back into the model + with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): + model.load_state_dict(model_state_dict, strict=False) + logger.debug(f"Rank {rank}: Loaded updated state dict into model.") + + return slice_files + + +async def load_files_for_window_and_rank( + window: int, rank: int, logger, key: str = "slice" +) -> List[str]: + """ + Retrieves the paths to downloaded window files for a specific rank from the temporary directory. + + Args: + window (int): The window identifier. + rank (int): The rank identifier. + + Returns: + List[str]: A list of file paths corresponding to the window and rank. 
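+
+    Note:
+        Matches files named f"{key}-{window}-rank{rank}-*.pt" in the system
+        temporary directory, i.e. the slices written by upload_slice_for_window.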
+ """ + logger.debug( + f"Retrieving files for window {window} and rank {rank} from temporary directory" + ) + temp_dir = tempfile.gettempdir() + window_files = [] + file_pattern = f"{key}-{window}-rank{rank}-" + for filename in os.listdir(temp_dir): + if filename.startswith(file_pattern) and filename.endswith(".pt"): + file_path = os.path.join(temp_dir, filename) + window_files.append(file_path) + logger.debug(f"Found file {filename} for window {window} and rank {rank}") + return window_files + + +def run_fsdp(rank, world_size, master_port): + logger = logging.getLogger(f"Rank {rank}") + try: + logger = setup_logging(rank) + logger.info(f"Running on rank {rank}") + setup(rank, world_size, master_port) + + model = SimpleTransformer().to(rank) + # Count and log the number of parameters in the model + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.info(f"Total parameters in the model: {total_params:,}") + logger.info(f"Trainable parameters in the model: {trainable_params:,}") + + fsdp_model = FSDP(model) + seed = "test_seed" + compression = 10000 + window = 125 + + class MockWallet: + class hotkey: + ss58_address = f"test_address_{rank}" + + mock_wallet = MockWallet() + + # Generate and save slices + filename = asyncio.run( + upload_slice_for_window( + bucket="decis", # Not used in this context + model=fsdp_model, + window=window, + wallet=mock_wallet, + seed=seed, + compression=compression, + logger=logger, + ) + ) + + logger.debug(f"Rank {rank}: Slice saved to {filename}") + + # Apply slices to the model + slice_files = asyncio.run( + apply_slices_to_model( + model=fsdp_model, + window=window, + seed=seed, + compression=compression, + logger=logger, + ) + ) + + logger.debug(f"Rank {rank}: Applied slices: {slice_files}") + + except Exception as e: + logger.exception(f"An error occurred: {str(e)}") + finally: + cleanup(rank) + + +def get_available_gpus(): + return torch.cuda.device_count() + + +def find_free_port(): + """Finds a free port on the localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +if __name__ == "__main__": + world_size = 4 + available_gpus = get_available_gpus() + if available_gpus < world_size: + print(f"Insufficient GPUs. Available: {available_gpus}, Required: {world_size}") + sys.exit(1) + + # Dynamically find a free port to avoid EADDRINUSE + master_port = find_free_port() + print(f"Starting FSDP with world_size={world_size} on port {master_port}") + + mp.spawn(run_fsdp, args=(world_size, master_port), nprocs=world_size, join=True) From 1d6d3364e33bdaf00cb4f73bcc5e2aca674a4035 Mon Sep 17 00:00:00 2001 From: distributedstatemachine Date: Fri, 25 Oct 2024 19:28:30 +0400 Subject: [PATCH 2/9] chore: refactor --- tests/slices.py | 108 +++++++++++++++++++++++------------------------- 1 file changed, 52 insertions(+), 56 deletions(-) diff --git a/tests/slices.py b/tests/slices.py index d32ecf8..e2bb6b8 100644 --- a/tests/slices.py +++ b/tests/slices.py @@ -49,62 +49,6 @@ def forward(self, x): x = self.layers[layer_idx](x) return self.fc(x) - -def setup(rank, world_size, master_port): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(master_port) - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) - logger = logging.getLogger(f"Rank {rank}") - logger.debug(f"Process group initialized. 
Using GPU {rank}") - - -def setup_logging(rank: int) -> logging.Logger: - """ - Sets up a logger for the given rank. - - Args: - rank (int): The rank of the current process. - - Returns: - logging.Logger: Configured logger for the rank. - """ - logger = logging.getLogger(f"Rank {rank}") - logger.setLevel(logging.DEBUG) - - # Create formatter - formatter = logging.Formatter( - "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - - # Create File Handler - file_handler = logging.FileHandler(f"log_rank_{rank}.log") - file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter(formatter) - - # Create Stream Handler (optional, can be removed if not needed) - stream_handler = logging.StreamHandler() - stream_handler.setLevel(logging.DEBUG) - stream_handler.setFormatter(formatter) - - # Avoid adding multiple handlers to the logger - if not logger.handlers: - logger.addHandler(file_handler) - logger.addHandler(stream_handler) - - return logger - - -def cleanup(rank): - logger = logging.getLogger(f"Rank {rank}") - if dist.is_initialized(): - dist.destroy_process_group() - logger.debug("Process group destroyed.") - else: - logger.warning("Process group not initialized; skipping cleanup.") - - async def upload_slice_for_window( bucket: str, model: nn.Module, @@ -406,7 +350,59 @@ async def load_files_for_window_and_rank( window_files.append(file_path) logger.debug(f"Found file {filename} for window {window} and rank {rank}") return window_files +def setup(rank, world_size, master_port): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + logger = logging.getLogger(f"Rank {rank}") + logger.debug(f"Process group initialized. Using GPU {rank}") + + +def setup_logging(rank: int) -> logging.Logger: + """ + Sets up a logger for the given rank. + + Args: + rank (int): The rank of the current process. + + Returns: + logging.Logger: Configured logger for the rank. 
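+
+    Note:
+        Handlers are attached only when the logger has none, so repeated
+        calls for the same rank do not duplicate log lines.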
+ """ + logger = logging.getLogger(f"Rank {rank}") + logger.setLevel(logging.DEBUG) + + # Create formatter + formatter = logging.Formatter( + "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # Create File Handler + file_handler = logging.FileHandler(f"log_rank_{rank}.log") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + # Create Stream Handler (optional, can be removed if not needed) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.DEBUG) + stream_handler.setFormatter(formatter) + + # Avoid adding multiple handlers to the logger + if not logger.handlers: + logger.addHandler(file_handler) + logger.addHandler(stream_handler) + + return logger + + +def cleanup(rank): + logger = logging.getLogger(f"Rank {rank}") + if dist.is_initialized(): + dist.destroy_process_group() + logger.debug("Process group destroyed.") + else: + logger.warning("Process group not initialized; skipping cleanup.") def run_fsdp(rank, world_size, master_port): logger = logging.getLogger(f"Rank {rank}") From af9e1fde363d0eaafa78ecfd433586843ae8957c Mon Sep 17 00:00:00 2001 From: distributedstatemachine Date: Mon, 28 Oct 2024 02:18:39 +0400 Subject: [PATCH 3/9] feat: fsdp miner + slices --- boltz/common.py | 631 ++++++++++++++++++++++++++++++++++++++++++++++++ boltz/fsdp.py | 31 +++ miner.py | 94 +++++++- 3 files changed, 743 insertions(+), 13 deletions(-) create mode 100644 boltz/common.py create mode 100644 boltz/fsdp.py diff --git a/boltz/common.py b/boltz/common.py new file mode 100644 index 0000000..eb8690f --- /dev/null +++ b/boltz/common.py @@ -0,0 +1,631 @@ +# The MIT License (MIT) +# © 2024 Chakana.tech + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. 
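+
+# Shared helpers for the slice protocol: S3 upload/download of parameter
+# slices, deterministic index selection, and FSDP-aware application of
+# averaged slices back onto the sharded model.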
+ +import os +import io +import sys +import uuid +import time +import fcntl +import torch +import uvloop +import hashlib +import asyncio +import logging +import tempfile +import aiofiles +import numpy as np +import aiobotocore +import bittensor as bt +import botocore.config +from typing import List, Dict +from dotenv import dotenv_values +from types import SimpleNamespace +from rich.logging import RichHandler +from filelock import FileLock, Timeout +from aiobotocore.session import get_session +from rich.highlighter import NullHighlighter +import torch +import hashlib +import numpy as np +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import FullStateDictConfig, StateDictType +import torch.distributed as dist +from torch import nn +import traceback + +# Configure loguru logger +FORMAT = "%(message)s" +logging.basicConfig( + level=logging.INFO, + format=FORMAT, + datefmt="[%X]", + handlers=[ + RichHandler( + markup=True, + rich_tracebacks=True, + highlighter=NullHighlighter(), + show_level=False, + show_time=False, + show_path=False + ) + ] +) +logger = logging.getLogger("rich") +logger.setLevel(logging.INFO) +def debug(): + logger.setLevel(logging.DEBUG) +def trace(): + logger.setLevel(logging.TRACE) +# Log helper. +def T(): return time.time() +def P( w, d ): return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])" + +# Load environment variables +env_config = {**dotenv_values(".env"), **os.environ} +AWS_ACCESS_KEY_ID = env_config.get('AWS_ACCESS_KEY_ID') +AWS_SECRET_ACCESS_KEY = env_config.get('AWS_SECRET_ACCESS_KEY') + +# Configure the S3 client +client_config = botocore.config.Config( + max_pool_connections=256, +) + +# Set uvloop as the event loop policy +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +# Define a semaphore to limit concurrent downloads (adjust as needed) +semaphore = asyncio.Semaphore(1000) + +async def get_slices( filename:str, device:str ) -> Dict[str, torch.Tensor]: + # Attempt to acquire the lock with a timeout of 1 second. + lock: FileLock = FileLock(f"{filename}.lock") + with lock.acquire(timeout=5): + pass + return torch.load( + filename, + map_location=torch.device(device), + weights_only = True, + ) + +async def apply_slices_to_model( + model: nn.Module, + window: int, + seed: str, + compression: int, + key: str = 'slice' +) -> List[str]: + """ + Applies slices from a specific window to the given FSDP model. + + Args: + model (torch.nn.Module): The FSDP-wrapped PyTorch model to which the slices will be applied. + window (int): The window identifier. + seed (str): The seed used for generating indices. + compression (int): The compression factor. + key (str): The key used to identify the slices. + + Returns: + List[str]: A list of all the slice files that were applied. + + Example: + slice_files = await apply_slices_to_model( + model=my_fsdp_model, + window=42, + seed="1234", + compression=10, + key='slice', + ) + + Notes: + - This function is adapted to work with FSDP. It ensures that all ranks participate + in collective operations required by FSDP to prevent hangs. + - Exception handling is added to ensure that any errors are caught, and all ranks exit gracefully. 
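+        - Slice files are read and averaged on rank 0 only; the other ranks
+          receive the result via the state_dict broadcast performed below.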
+ """ + rank = dist.get_rank() + world_size = dist.get_world_size() + logger.debug(f"Rank {rank}: Starting apply_slices_to_model") + + # Get indices associated with the window (all ranks must participate) + try: + indices_dict: Dict[str, torch.LongTensor] = await get_indices_for_window(model, seed, compression) + except Exception as e: + logger.exception(f"Rank {rank}: Failed to get indices: {e}") + sys.exit(1) # Ensure all ranks exit to prevent hangs + + logger.debug(f"Rank {rank}: Retrieved indices for parameters") + + # Load slice files on rank 0 and broadcast the list to all ranks + if rank == 0: + try: + slice_files: List[str] = await load_files_for_window(window=window, key=key) + logger.debug(f"Rank {rank}: Loaded {len(slice_files)} slice files") + except Exception as e: + logger.exception(f"Rank {rank}: Failed to load slice files: {e}") + slice_files = [] + else: + slice_files = [] + + # # Broadcast the slice_files list to all ranks + # slice_files_list = [slice_files] + # dist.broadcast_object_list(slice_files_list, src=0) + # slice_files = slice_files_list[0] + + if not slice_files: + logger.warning(f"Rank {rank}: No slice files to process for window {window}") + return slice_files # Early return, but all ranks have synchronized here + + # Initialize dictionaries to keep track of sums and counts + param_sums: Dict[str, torch.Tensor] = {} + slices_per_param: Dict[str, int] = {} + + # Rank 0 processes the slice files and reconstructs the parameters + if rank == 0: + for name in indices_dict.keys(): + param_sums[name] = torch.zeros(indices_dict[name].numel(), dtype=torch.float32) + slices_per_param[name] = 0 + + # Process each slice file + for file_i in slice_files: + logger.debug(f"Rank {rank}: Processing slice file {file_i}") + try: + slice_i = await get_slices(file_i, 'cpu') # Load slices to CPU + for name in slice_i.keys(): + if name in param_sums: + param_sums[name] += slice_i[name].cpu() + slices_per_param[name] += 1 + except Exception as e: + logger.exception(f"Rank {rank}: Error processing slice file {file_i}: {e}") + continue + + # Average the sums to get the updated parameters + for name in param_sums.keys(): + if slices_per_param[name] > 0: + param_sums[name] /= slices_per_param[name] + else: + logger.warning(f"Rank {rank}: No slices applied for parameter {name}") + + # Broadcast the param_sums and slices_per_param to all ranks + # param_sums_list = [param_sums] + # slices_per_param_list = [slices_per_param] + # dist.broadcast_object_list(param_sums_list, src=0) + # dist.broadcast_object_list(slices_per_param_list, src=0) + # param_sums = param_sums_list[0] + # slices_per_param = slices_per_param_list[0] + + # All ranks participate in updating the model parameters + try: + # Retrieve the full state_dict (all ranks must participate) + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + state_dict: Dict[str, torch.Tensor] = model.state_dict() + + for name, param in state_dict.items(): + if name in indices_dict and name in param_sums and slices_per_param[name] > 0: + indices = indices_dict[name].to('cpu') + updated_values = param_sums[name].to(param.dtype) + # Update the parameter values at the specified indices + param.view(-1)[indices] = updated_values + logger.trace(f"Rank {rank}: Updated parameter {name}") + else: + logger.trace(f"Rank {rank}: No updates applied to parameter {name}") + + # Broadcast the updated state_dict from rank 0 to all other ranks + state_dict_list = [state_dict] 
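+        # broadcast_object_list fills the list in place: after the call the
+        # single entry on every rank holds rank 0's state_dict.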
+ dist.broadcast_object_list(state_dict_list, src=0) + state_dict = state_dict_list[0] + logger.debug(f"Rank {rank}: Received updated state_dict from broadcast") + + # Load the updated state_dict back into the model (all ranks must participate) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + model.load_state_dict(state_dict, strict=False) + logger.debug(f"Rank {rank}: Model parameters updated") + + except Exception as e: + logger.exception(f"Rank {rank}: Failed to update model parameters: {e}") + sys.exit(1) # Ensure all ranks exit + + return slice_files + +async def upload_slice_for_window( + bucket: str, + model: torch.nn.Module, + window: int, + seed: str, + wallet: 'bt.wallet', + compression: int, + key: str = 'slice' +): + """ + Uploads a compressed slice of an FSDP model to a storage bucket. + + Args: + bucket (str): Name of the storage bucket. + model (torch.nn.Module): The FSDP-wrapped PyTorch model to be sliced and uploaded. + window (int): The window identifier. + seed (str): The seed used for generating indices. + wallet (bt.wallet): The wallet object containing the hotkey. + compression (int): The compression factor. + key (str): The key used to identify the slices. + + Returns: + None + + Example Usage: + await upload_slice_for_window( + bucket='my-bucket', + model=my_fsdp_model, + window=42, + seed='1234', + wallet=my_wallet, + compression=10, + key='slice' + ) + + Notes: + - This function ensures that all ranks participate in necessary collective operations with FSDP. + - Only Rank 0 performs the actual upload to the storage bucket. + - All ranks must participate in collective operations to prevent hangs. + """ + rank = dist.get_rank() + logger.debug(f"Rank {rank}: Starting upload_slice_for_window") + + # Generate the filename based on the window and wallet hotkey + filename = f'{key}-{window}-{wallet.hotkey.ss58_address}.pt' + logger.debug(f"Rank {rank}: Filename for slice: {filename}") + + try: + # Get indices for slicing (all ranks must participate) + indices = await get_indices_for_window(model, seed, compression) + logger.debug(f"Rank {rank}: Retrieved indices for slicing") + + # Retrieve the full state_dict (all ranks must participate) + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + state_dict = model.state_dict() + logger.debug(f"Rank {rank}: Retrieved state_dict for slicing") + + # Prepare the sliced state_dict + slice_dict = {} + for name, param in state_dict.items(): + if name in indices: + param_indices = indices[name].to('cpu') + sliced_param = param.view(-1)[param_indices].cpu() + slice_dict[name] = sliced_param + logger.trace(f"Rank {rank}: Sliced parameter {name}") + else: + logger.trace(f"Rank {rank}: No indices for parameter {name}; skipping") + except Exception as e: + logger.exception(f"Rank {rank}: Error during slicing: {e}") + sys.exit(1) # Ensure all ranks exit to prevent hangs + + # Only Rank 0 saves and uploads the slice + if rank == 0: + try: + # Save the slice_dict to a temporary file + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + torch.save(slice_dict, temp_file) + temp_file_name = temp_file.name + logger.debug(f"Rank {rank}: Saved slice to temporary file {temp_file_name}") + + # Initialize S3 client + session = get_session() + async with session.create_client( + 's3', + region_name='us-east-1', # Replace with your region + config=client_config, + aws_access_key_id=AWS_ACCESS_KEY_ID, + 
aws_secret_access_key=AWS_SECRET_ACCESS_KEY + ) as s3_client: + try: + # Upload the file to S3 + with open(temp_file_name, 'rb') as f: + await s3_client.put_object(Bucket=bucket, Key=filename, Body=f) + logger.debug(f"Rank {rank}: Uploaded slice to bucket {bucket} with key {filename}") + + # Optionally set object ACL to public-read + await s3_client.put_object_acl(Bucket=bucket, Key=filename, ACL='public-read') + logger.debug(f"Rank {rank}: Set object ACL to public-read for {filename}") + except Exception as e: + logger.exception(f"Rank {rank}: Failed to upload slice to storage: {e}") + finally: + # Clean up the temporary file + os.remove(temp_file_name) + logger.debug(f"Rank {rank}: Removed temporary file {temp_file_name}") + except Exception as e: + logger.exception(f"Rank {rank}: Error during saving or uploading slice: {e}") + sys.exit(1) # Ensure all ranks exit to prevent hangs + else: + logger.debug(f"Rank {rank}: Slice preparation complete. Waiting for Rank 0 to upload.") + + # Synchronize all ranks to ensure upload is completed before proceeding + dist.barrier() + logger.debug(f"Rank {rank}: Completed upload_slice_for_window") + +async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: int) -> Dict[str, torch.LongTensor]: + """ + Computes the indices for the given window and compression factor. + + Args: + model (torch.nn.Module): The PyTorch model. + seed (str): The window seed identifier. + compression (int): The compression factor. + + Returns: + Dict[str, torch.LongTensor]: A dictionary mapping parameter names to index tensors. + """ + logger.debug(f"Starting get_indices_for_window with seed={seed}, compression={compression}") + result = {} + + # Seed the random number generator with the seed + seed_int = int(hashlib.md5(str(seed).encode('utf-8')).hexdigest(), 16) % (2**32) + logger.trace(f"Converted seed '{seed}' to integer: {seed_int}") + rng = np.random.default_rng(seed_int) + logger.trace(f"Initialized random number generator with seed: {seed_int}") + + # Retrieve the full state dict to get parameter shapes + logger.trace("Retrieving full state dict") + dist.barrier() + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): + state_dict = model.state_dict() + logger.trace(f"Retrieved state dict with {len(state_dict)} parameters") + + # For each parameter, compute the indices + for name, param in state_dict.items(): + logger.trace(f"Processing parameter: {name}") + numel = param.numel() + logger.trace(f"Parameter {name} has {numel} elements") + num_indices = max(1, int(numel // compression)) + logger.trace(f"Selecting {num_indices} indices for parameter {name}") + indices = rng.choice(numel, size=num_indices, replace=False) + logger.trace(f"Generated indices for {name}: min={indices.min()}, max={indices.max()}, shape={indices.shape}") + result[name] = torch.from_numpy(indices).long().cpu() + logger.trace(f"Converted indices for {name} to torch.LongTensor on CPU") + + logger.trace(f"Finished get_indices_for_window, returning dict with {len(result)} entries") + return result + + +async def download_file(s3_client, bucket: str, filename: str) -> str: + """ + Downloads a file from S3, using parallel downloads for large files. + + Args: + s3_client: The S3 client. + bucket (str): Name of the S3 bucket. + filename (str): The S3 object key (filename). + + Returns: + str: The path to the downloaded file in the temporary directory. 
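+
+    Note:
+        The transfer is guarded by a per-file FileLock and streamed in 1 MB
+        chunks; a file that already exists locally is reused. Returns None on
+        lock timeout or download failure.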
+ """ + async with semaphore: + temp_file = os.path.join(tempfile.gettempdir(), filename) + # Check if the file exists. + if os.path.exists(temp_file): + logger.debug(f"File {temp_file} already exists, skipping download.") + return temp_file + lock_file = f"{temp_file}.lock" + lock = FileLock(lock_file) + try: + # Try to acquire both locks with a timeout + with lock.acquire(timeout=1): + # Proceed to download the file + logger.debug(f"Downloading file {filename} to {temp_file}") + head_response = await s3_client.head_object(Bucket=bucket, Key=filename) + object_size = head_response['ContentLength'] + CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB + + response = await s3_client.get_object(Bucket=bucket, Key=filename) + async with aiofiles.open(temp_file, 'wb') as outfile: + while True: + chunk = await response['Body'].read(CHUNK_SIZE) + if not chunk: + break + await outfile.write(chunk) + + logger.debug(f"Successfully downloaded file {filename} to {temp_file}") + return temp_file + + except Timeout: + logger.error(f"Timeout occurred while trying to acquire lock on {lock_file}") + return None + except Exception as e: + logger.exception(f"Failed to download file {filename} from bucket {bucket}: {e}") + return None + finally: + # The lock is automatically released when exiting the 'with' block + pass + +async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window: int): + """ + Handles downloading a single file from S3. + + Args: + s3_client: The S3 client. + bucket (str): Name of the S3 bucket. + filename (str): The S3 object key (filename). + hotkey (str): The hotkey identifier. + window (int): The window identifier. + + Returns: + SimpleNamespace: An object containing file metadata and the path to the downloaded file. + """ + logger.debug(f"Handling file {filename} for window {window} and hotkey {hotkey}") + temp_file = await download_file(s3_client, bucket, filename) + if temp_file: + return SimpleNamespace(bucket=bucket, hotkey=hotkey, filename=filename, window=window, temp_file=temp_file) + return None + +async def process_bucket(s3_client, bucket: str, windows: List[int], key:str = 'slice'): + """ + Processes an S3 bucket to download files matching the given windows. + + Args: + s3_client: The S3 client. + bucket (str): Name of the S3 bucket. + windows (List[int]): A list of window identifiers. + + Returns: + List[SimpleNamespace]: A list of file metadata and paths for downloaded files. 
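+
+    Note:
+        Object keys are expected to look like f"{key}-{window}-{hotkey}.pt";
+        keys that fail to parse are logged and skipped.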
+ """ + logger.debug(f"Processing bucket {bucket} for window {windows}") + files = [] + paginator = s3_client.get_paginator('list_objects_v2') + + for window in windows: + prefix = f'{key}-{window}' + logger.debug(f"Listing objects with prefix {prefix}") + async for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + logger.trace(f"Processing page for prefix {prefix}") + if 'Contents' not in page: + logger.trace(f"No contents found for prefix {prefix}") + continue + download_tasks = [] + for obj in page.get('Contents', []): + filename = obj['Key'] + logger.trace(f"Processing object with key {filename}") + try: + parts = filename.split('-') + slice_window = int(parts[1]) + slice_hotkey = parts[2].split('.')[0] + logger.trace(f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}") + if slice_window == window: + download_tasks.append(handle_file(s3_client, bucket, filename, slice_hotkey, slice_window)) + except Exception: + logger.exception(f"Error processing filename {filename}") + continue + # Download the files concurrently + results = await asyncio.gather(*download_tasks) + files.extend([res for res in results if res]) + logger.trace(f"Completed processing page for prefix {prefix}") + logger.trace(f"Completed processing bucket {bucket} for windows {windows}") + return files + +async def download_slices_for_buckets_and_windows(buckets: List[str], windows: List[int], key:str = 'slice') -> Dict[int, List[SimpleNamespace]]: + """ + Downloads files from multiple S3 buckets for the given windows. + + Args: + buckets (List[str]): A list of S3 bucket names. + windows (List[int]): A list of window identifiers. + + Returns: + Dict[int, List[SimpleNamespace]]: A dictionary mapping windows to lists of file metadata and paths. + """ + logger.debug(f"Downloading files for buckets {set(buckets)} and windows {windows}") + session = get_session() + async with session.create_client( + 's3', + region_name='us-east-1', + config=client_config, + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY + ) as s3_client: + tasks = [] + for bucket in set(buckets): + if not bucket: + continue + tasks.append(process_bucket(s3_client, bucket, windows, key)) + results = await asyncio.gather(*tasks) + # Flatten the list of lists + files = [item for sublist in results for item in sublist] + + # Create a dictionary with windows as keys and list of files as values + windows_dict = {} + for file in files: + window = file.window + if window not in windows_dict: + windows_dict[window] = [] + windows_dict[window].append(file) + + logger.debug(f"Downloaded all files grouped by windows: {windows}") + return windows_dict + +async def load_files_for_window(window: int, key: str = 'slice') -> List[str]: + """ + Retrieves the paths to downloaded window files from the temporary directory. + + Args: + window (int): The window identifier. + + Returns: + List[str]: A list of file paths corresponding to the window. 
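+
+    Note:
+        Scans the system temporary directory for files matching
+        f"{key}-{window}-*.pt", i.e. the files placed there by download_file.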
+ """ + logger.debug(f"Retrieving files for window {window} from temporary directory") + temp_dir = tempfile.gettempdir() + window_files = [] + for filename in os.listdir(temp_dir): + if filename.startswith(f"{key}-{window}-") and filename.endswith(".pt"): + window_files.append(os.path.join(temp_dir, filename)) + logger.debug(f"Found file {filename} for window {window}") + return window_files + +async def delete_files_before_window(window_max: int, key:str = 'slice'): + """ + Deletes all files on the local machine which have a window id before a specific value window_max. + + Args: + window_max (int): The maximum window id. Files with window ids less than this value will be deleted. + """ + logger.debug(f"Deleting files with window id before {window_max}") + temp_dir = tempfile.gettempdir() + for filename in os.listdir(temp_dir): + if filename.startswith(f"{key}-") and ( filename.endswith(".pt") or filename.endswith(".lock") ): + try: + parts = filename.split('-') + window_id = int(parts[1]) + if window_id < window_max: + if os.path.exists(filename): + os.remove(filename) + logger.debug(f"Deleted file {filename}") + except Exception as e: + logger.error(f"Error deleting file {filename}: {e}") + +async def delete_files_from_bucket_before_window(bucket: str, window_max: int, key: str = 'slice'): + """ + Deletes all files in the specified S3 bucket which have a window id before a specific value window_max. + + Args: + bucket (str): The name of the S3 bucket. + window_max (int): The maximum window id. Files with window ids less than this value will be deleted. + """ + logger.debug(f"Deleting files in bucket {bucket} with window id before {window_max}") + session = get_session() + async with session.create_client( + 's3', + region_name='us-east-1', + config=client_config, + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY + ) as s3_client: + try: + response = await s3_client.list_objects_v2(Bucket=bucket) + if 'Contents' in response: + for obj in response['Contents']: + filename = obj['Key'] + if filename.startswith(f"{key}-") and filename.endswith(".pt"): + try: + parts = filename.split('-') + window_id = int(parts[1]) + if window_id < window_max: + await s3_client.delete_object(Bucket=bucket, Key=filename) + logger.debug(f"Deleted file {filename} from bucket {bucket}") + except Exception as e: + logger.error(f"Error deleting file {filename} from bucket {bucket}: {e}") + except Exception as e: + logger.error(f"Error listing objects in bucket {bucket}: {e}") diff --git a/boltz/fsdp.py b/boltz/fsdp.py new file mode 100644 index 0000000..b252ac1 --- /dev/null +++ b/boltz/fsdp.py @@ -0,0 +1,31 @@ +from torch.distributed import DeviceMesh +from torch.distributed.fsdp.wrap import ( + _or_policy, + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, +) + + +def fsdp_auto_wrap_policy(model, transformer_layer_names): + import functools + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn + ) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=set(transformer_layer_names) + ) + + auto_wrap_policy = functools.partial( + _or_policy, policies=[lambda_policy, transformer_wrap_policy] + ) + return auto_wrap_policy diff --git a/miner.py b/miner.py index df79ae7..1b078f6 100644 --- 
a/miner.py +++ b/miner.py @@ -27,19 +27,23 @@ import asyncio import argparse import threading -import traceback from tqdm import tqdm import bittensor as bt -from typing import List import torch.optim as optim from dotenv import dotenv_values from transformers import LlamaForCausalLM from torch.optim.lr_scheduler import CosineAnnealingLR +import torch.distributed as dist +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from functools import partial +from transformers.models.llama.modeling_llama import LlamaDecoderLayer # Import local files. -from common import * +from boltz.common import * from hparams import load_hparams from dataset import DatasetLoader +from boltz.fsdp import fsdp_auto_wrap_policy # GPU optimizations. torch.backends.cudnn.benchmark = True @@ -52,8 +56,8 @@ class Miner: def config(): parser = argparse.ArgumentParser(description='Miner script') parser.add_argument('--project', type=str, default='aesop2', help='Optional wandb project name') - parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') - parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') + parser.add_argument('--netuid', type=int, default=223, help='Bittensor network UID.') + parser.add_argument('--bucket', type=str, default='cont2', help='S3 bucket name') parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') @@ -106,14 +110,66 @@ def __init__(self): except: pass wandb.init(project=self.config.project, resume='allow', name=f'M{self.uid}', config=self.config) + # Initialize distributed training + if torch.cuda.is_available() and "LOCAL_RANK" in os.environ: + # torchrun provides LOCAL_RANK, RANK, and WORLD_SIZE environment variables + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.global_rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + + num_gpus = torch.cuda.device_count() + if self.local_rank >= num_gpus: + raise ValueError(f"Local rank {self.local_rank} exceeds number of available GPUs {num_gpus}.") + + # Set the device for this process + torch.cuda.set_device(self.local_rank) + self.device = torch.device("cuda", self.local_rank) + + # Initialize the process group + dist.init_process_group(backend='nccl') + logger.info(f"Distributed training initialized on rank {self.global_rank} out of {self.world_size} processes.") + else: + # Single process execution + self.local_rank = 0 + self.global_rank = 0 + self.world_size = 1 + self.device = torch.device(self.config.device) + logger.warning("Distributed training is not initialized. Running on a single process.") + + + # Identify if the current process is the master (rank 0). + is_master = self.global_rank == 0 + # Init model. 
logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) self.hparams = load_hparams() torch.manual_seed(42); np.random.seed(42); random.seed(42) self.model = LlamaForCausalLM(config=self.hparams.model_config) - # self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') - self.model.to(self.config.device) - self.model.train() + + # Wrap the model with FSDP if distributed training is initialized + if dist.is_initialized(): + # Define the transformer layer names to wrap + transformer_layer_names = [LlamaDecoderLayer] + + # Create the custom auto wrap policy + auto_wrap_policy = fsdp_auto_wrap_policy( + model=self.model, + transformer_layer_names=transformer_layer_names + ) + + # Wrap the model with FSDP using the custom auto wrap policy + self.model = FSDP( + self.model, + device_id=self.local_rank, + auto_wrap_policy=auto_wrap_policy, + sharding_strategy=ShardingStrategy.FULL_SHARD, + ) + logger.info(f"Model wrapped with FSDP on device {self.device} using custom auto wrap policy.") + else: + # Move the model to the device for single-process execution + self.model.to(self.device) + logger.info(f"Model moved to device {self.device}.") + self.model.train() self.optimizer = optim.AdamW( self.model.parameters(), lr=self.hparams.learning_rate, # Peak learning rate @@ -125,7 +181,8 @@ def __init__(self): self.optimizer, T_max=self.hparams.cosine_epoch_length, eta_min=self.hparams.eta_min, last_epoch=-1 ) - + # Synchronize all processes + dist.barrier() # Init buckets. self.buckets = [] for uid in self.metagraph.uids: @@ -166,6 +223,7 @@ async def run(self): self.listener = threading.Thread(target=self.block_listener, args=(self.loop,), daemon=True).start() # Optionally sync the model state by pulling model states from the history. + dist.barrier() if self.config.sync_state: history_windows = [ self.current_window - i for i in range (self.hparams.max_history) ] state_slices = await download_slices_for_buckets_and_windows( @@ -193,6 +251,7 @@ async def run(self): window = self.current_window # Run for non-baseline miners. + dist.barrier() if not self.config.baseline: st = T() state_slices = await download_slices_for_buckets_and_windows( @@ -250,7 +309,7 @@ async def run(self): total_steps += 1 if random.random() < self.sample_rate and not exhuasted_window: full_steps += 1 - input_ids = torch.tensor(batch, dtype=torch.long).to(self.model.device) + input_ids = torch.tensor(batch, dtype=torch.long).to(self.device) labels = input_ids.clone() labels = torch.where(labels == self.hparams.tokenizer.pad_token_id, -100, labels) with torch.amp.autocast(device_type=self.model.device.type, dtype=torch.bfloat16): # Enable autocasting @@ -273,7 +332,7 @@ async def run(self): logger.info(f"{P(window, train_duration)} \tLoss: [tan]{step_loss}[tan]") if exhuasted_window: self.sample_rate = max(0.0001, self.sample_rate * 0.95) else: self.sample_rate = min(1, self.sample_rate * 1.05) - + dist.barrier() # Run for non-baseline nodes. if not self.config.baseline: # Upload the delta for the previous window. @@ -338,11 +397,14 @@ async def run(self): f"learning_rate": self.scheduler.get_last_lr()[0] }) - # Catch keyboard interrrupt. + # Catch keyboard interrupt. except KeyboardInterrupt: logger.info("Training interrupted by user. Stopping the run.") self.stop_event.set() await self.update_task + if dist.is_initialized(): + dist.destroy_process_group() + logger.info("Destroyed process group.") sys.exit(0) # Catch unknown. 
@@ -381,4 +443,10 @@ def handler(event, _u, _s): time.sleep(1) if __name__ == "__main__": - asyncio.run(Miner().run()) + try: + asyncio.run(Miner().run()) + finally: + # Ensure process group is destroyed on exit + if dist.is_initialized(): + dist.destroy_process_group() + logger.info("Destroyed process group on exit.") From e176581b73aa6febae325e250e9c524e4cef262f Mon Sep 17 00:00:00 2001 From: distributedstatemachine Date: Mon, 28 Oct 2024 02:28:05 +0400 Subject: [PATCH 4/9] refactor: move files to boltz --- boltz/common.py | 258 ++++++++++------ boltz/dataset.py | 483 ++++++++++++++++++++++++++++++ boltz/hparams.py | 87 ++++++ scripts/run.sh | 555 +++++++++++++++++++++++++++++++++++ scripts/start.sh | 36 +++ scripts/start_distributed.sh | 43 +++ 6 files changed, 1367 insertions(+), 95 deletions(-) create mode 100644 boltz/dataset.py create mode 100644 boltz/hparams.py create mode 100755 scripts/run.sh create mode 100755 scripts/start.sh create mode 100644 scripts/start_distributed.sh diff --git a/boltz/common.py b/boltz/common.py index eb8690f..482e60e 100644 --- a/boltz/common.py +++ b/boltz/common.py @@ -17,7 +17,7 @@ import os import io -import sys +import sys import uuid import time import fcntl @@ -50,35 +50,46 @@ # Configure loguru logger FORMAT = "%(message)s" -logging.basicConfig( - level=logging.INFO, - format=FORMAT, - datefmt="[%X]", +logging.basicConfig( + level=logging.INFO, + format=FORMAT, + datefmt="[%X]", handlers=[ RichHandler( - markup=True, - rich_tracebacks=True, + markup=True, + rich_tracebacks=True, highlighter=NullHighlighter(), show_level=False, show_time=False, - show_path=False + show_path=False, ) - ] + ], ) logger = logging.getLogger("rich") logger.setLevel(logging.INFO) + + def debug(): logger.setLevel(logging.DEBUG) + + def trace(): logger.setLevel(logging.TRACE) + + # Log helper. -def T(): return time.time() -def P( w, d ): return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])" +def T(): + return time.time() + + +def P(w, d): + return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])" + # Load environment variables env_config = {**dotenv_values(".env"), **os.environ} -AWS_ACCESS_KEY_ID = env_config.get('AWS_ACCESS_KEY_ID') -AWS_SECRET_ACCESS_KEY = env_config.get('AWS_SECRET_ACCESS_KEY') +AWS_ACCESS_KEY_ID = env_config.get("AWS_ACCESS_KEY_ID") +AWS_SECRET_ACCESS_KEY = env_config.get("AWS_SECRET_ACCESS_KEY") # Configure the S3 client client_config = botocore.config.Config( @@ -91,7 +102,8 @@ def P( w, d ): return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63]) # Define a semaphore to limit concurrent downloads (adjust as needed) semaphore = asyncio.Semaphore(1000) -async def get_slices( filename:str, device:str ) -> Dict[str, torch.Tensor]: + +async def get_slices(filename: str, device: str) -> Dict[str, torch.Tensor]: # Attempt to acquire the lock with a timeout of 1 second. lock: FileLock = FileLock(f"{filename}.lock") with lock.acquire(timeout=5): @@ -99,15 +111,12 @@ async def get_slices( filename:str, device:str ) -> Dict[str, torch.Tensor]: return torch.load( filename, map_location=torch.device(device), - weights_only = True, + weights_only=True, ) + async def apply_slices_to_model( - model: nn.Module, - window: int, - seed: str, - compression: int, - key: str = 'slice' + model: nn.Module, window: int, seed: str, compression: int, key: str = "slice" ) -> List[str]: """ Applies slices from a specific window to the given FSDP model. 
@@ -142,7 +151,9 @@ async def apply_slices_to_model( # Get indices associated with the window (all ranks must participate) try: - indices_dict: Dict[str, torch.LongTensor] = await get_indices_for_window(model, seed, compression) + indices_dict: Dict[str, torch.LongTensor] = await get_indices_for_window( + model, seed, compression + ) except Exception as e: logger.exception(f"Rank {rank}: Failed to get indices: {e}") sys.exit(1) # Ensure all ranks exit to prevent hangs @@ -160,11 +171,6 @@ async def apply_slices_to_model( else: slice_files = [] - # # Broadcast the slice_files list to all ranks - # slice_files_list = [slice_files] - # dist.broadcast_object_list(slice_files_list, src=0) - # slice_files = slice_files_list[0] - if not slice_files: logger.warning(f"Rank {rank}: No slice files to process for window {window}") return slice_files # Early return, but all ranks have synchronized here @@ -176,20 +182,24 @@ async def apply_slices_to_model( # Rank 0 processes the slice files and reconstructs the parameters if rank == 0: for name in indices_dict.keys(): - param_sums[name] = torch.zeros(indices_dict[name].numel(), dtype=torch.float32) + param_sums[name] = torch.zeros( + indices_dict[name].numel(), dtype=torch.float32 + ) slices_per_param[name] = 0 # Process each slice file for file_i in slice_files: logger.debug(f"Rank {rank}: Processing slice file {file_i}") try: - slice_i = await get_slices(file_i, 'cpu') # Load slices to CPU + slice_i = await get_slices(file_i, "cpu") # Load slices to CPU for name in slice_i.keys(): if name in param_sums: param_sums[name] += slice_i[name].cpu() slices_per_param[name] += 1 except Exception as e: - logger.exception(f"Rank {rank}: Error processing slice file {file_i}: {e}") + logger.exception( + f"Rank {rank}: Error processing slice file {file_i}: {e}" + ) continue # Average the sums to get the updated parameters @@ -199,14 +209,6 @@ async def apply_slices_to_model( else: logger.warning(f"Rank {rank}: No slices applied for parameter {name}") - # Broadcast the param_sums and slices_per_param to all ranks - # param_sums_list = [param_sums] - # slices_per_param_list = [slices_per_param] - # dist.broadcast_object_list(param_sums_list, src=0) - # dist.broadcast_object_list(slices_per_param_list, src=0) - # param_sums = param_sums_list[0] - # slices_per_param = slices_per_param_list[0] - # All ranks participate in updating the model parameters try: # Retrieve the full state_dict (all ranks must participate) @@ -215,8 +217,12 @@ async def apply_slices_to_model( state_dict: Dict[str, torch.Tensor] = model.state_dict() for name, param in state_dict.items(): - if name in indices_dict and name in param_sums and slices_per_param[name] > 0: - indices = indices_dict[name].to('cpu') + if ( + name in indices_dict + and name in param_sums + and slices_per_param[name] > 0 + ): + indices = indices_dict[name].to("cpu") updated_values = param_sums[name].to(param.dtype) # Update the parameter values at the specified indices param.view(-1)[indices] = updated_values @@ -241,14 +247,15 @@ async def apply_slices_to_model( return slice_files + async def upload_slice_for_window( bucket: str, model: torch.nn.Module, window: int, seed: str, - wallet: 'bt.wallet', + wallet, compression: int, - key: str = 'slice' + key: str = "slice", ): """ Uploads a compressed slice of an FSDP model to a storage bucket. 
@@ -285,7 +292,7 @@ async def upload_slice_for_window( logger.debug(f"Rank {rank}: Starting upload_slice_for_window") # Generate the filename based on the window and wallet hotkey - filename = f'{key}-{window}-{wallet.hotkey.ss58_address}.pt' + filename = f"{key}-{window}-{wallet.hotkey.ss58_address}.pt" logger.debug(f"Rank {rank}: Filename for slice: {filename}") try: @@ -303,7 +310,7 @@ async def upload_slice_for_window( slice_dict = {} for name, param in state_dict.items(): if name in indices: - param_indices = indices[name].to('cpu') + param_indices = indices[name].to("cpu") sliced_param = param.view(-1)[param_indices].cpu() slice_dict[name] = sliced_param logger.trace(f"Rank {rank}: Sliced parameter {name}") @@ -325,38 +332,55 @@ async def upload_slice_for_window( # Initialize S3 client session = get_session() async with session.create_client( - 's3', - region_name='us-east-1', # Replace with your region + "s3", + region_name="us-east-1", # Replace with your region config=client_config, aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, ) as s3_client: try: # Upload the file to S3 - with open(temp_file_name, 'rb') as f: + with open(temp_file_name, "rb") as f: await s3_client.put_object(Bucket=bucket, Key=filename, Body=f) - logger.debug(f"Rank {rank}: Uploaded slice to bucket {bucket} with key {filename}") + logger.debug( + f"Rank {rank}: Uploaded slice to bucket {bucket} with key {filename}" + ) # Optionally set object ACL to public-read - await s3_client.put_object_acl(Bucket=bucket, Key=filename, ACL='public-read') - logger.debug(f"Rank {rank}: Set object ACL to public-read for {filename}") + await s3_client.put_object_acl( + Bucket=bucket, Key=filename, ACL="public-read" + ) + logger.debug( + f"Rank {rank}: Set object ACL to public-read for {filename}" + ) except Exception as e: - logger.exception(f"Rank {rank}: Failed to upload slice to storage: {e}") + logger.exception( + f"Rank {rank}: Failed to upload slice to storage: {e}" + ) finally: # Clean up the temporary file os.remove(temp_file_name) - logger.debug(f"Rank {rank}: Removed temporary file {temp_file_name}") + logger.debug( + f"Rank {rank}: Removed temporary file {temp_file_name}" + ) except Exception as e: - logger.exception(f"Rank {rank}: Error during saving or uploading slice: {e}") + logger.exception( + f"Rank {rank}: Error during saving or uploading slice: {e}" + ) sys.exit(1) # Ensure all ranks exit to prevent hangs else: - logger.debug(f"Rank {rank}: Slice preparation complete. Waiting for Rank 0 to upload.") + logger.debug( + f"Rank {rank}: Slice preparation complete. Waiting for Rank 0 to upload." + ) # Synchronize all ranks to ensure upload is completed before proceeding dist.barrier() logger.debug(f"Rank {rank}: Completed upload_slice_for_window") -async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: int) -> Dict[str, torch.LongTensor]: + +async def get_indices_for_window( + model: torch.nn.Module, seed: str, compression: int +) -> Dict[str, torch.LongTensor]: """ Computes the indices for the given window and compression factor. @@ -368,11 +392,13 @@ async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: Returns: Dict[str, torch.LongTensor]: A dictionary mapping parameter names to index tensors. 
""" - logger.debug(f"Starting get_indices_for_window with seed={seed}, compression={compression}") + logger.debug( + f"Starting get_indices_for_window with seed={seed}, compression={compression}" + ) result = {} # Seed the random number generator with the seed - seed_int = int(hashlib.md5(str(seed).encode('utf-8')).hexdigest(), 16) % (2**32) + seed_int = int(hashlib.md5(str(seed).encode("utf-8")).hexdigest(), 16) % (2**32) logger.trace(f"Converted seed '{seed}' to integer: {seed_int}") rng = np.random.default_rng(seed_int) logger.trace(f"Initialized random number generator with seed: {seed_int}") @@ -393,11 +419,15 @@ async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: num_indices = max(1, int(numel // compression)) logger.trace(f"Selecting {num_indices} indices for parameter {name}") indices = rng.choice(numel, size=num_indices, replace=False) - logger.trace(f"Generated indices for {name}: min={indices.min()}, max={indices.max()}, shape={indices.shape}") + logger.trace( + f"Generated indices for {name}: min={indices.min()}, max={indices.max()}, shape={indices.shape}" + ) result[name] = torch.from_numpy(indices).long().cpu() logger.trace(f"Converted indices for {name} to torch.LongTensor on CPU") - logger.trace(f"Finished get_indices_for_window, returning dict with {len(result)} entries") + logger.trace( + f"Finished get_indices_for_window, returning dict with {len(result)} entries" + ) return result @@ -427,13 +457,13 @@ async def download_file(s3_client, bucket: str, filename: str) -> str: # Proceed to download the file logger.debug(f"Downloading file {filename} to {temp_file}") head_response = await s3_client.head_object(Bucket=bucket, Key=filename) - object_size = head_response['ContentLength'] + object_size = head_response["ContentLength"] CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB response = await s3_client.get_object(Bucket=bucket, Key=filename) - async with aiofiles.open(temp_file, 'wb') as outfile: + async with aiofiles.open(temp_file, "wb") as outfile: while True: - chunk = await response['Body'].read(CHUNK_SIZE) + chunk = await response["Body"].read(CHUNK_SIZE) if not chunk: break await outfile.write(chunk) @@ -442,15 +472,20 @@ async def download_file(s3_client, bucket: str, filename: str) -> str: return temp_file except Timeout: - logger.error(f"Timeout occurred while trying to acquire lock on {lock_file}") + logger.error( + f"Timeout occurred while trying to acquire lock on {lock_file}" + ) return None except Exception as e: - logger.exception(f"Failed to download file {filename} from bucket {bucket}: {e}") + logger.exception( + f"Failed to download file {filename} from bucket {bucket}: {e}" + ) return None finally: # The lock is automatically released when exiting the 'with' block pass + async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window: int): """ Handles downloading a single file from S3. 
@@ -468,10 +503,19 @@ async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window logger.debug(f"Handling file {filename} for window {window} and hotkey {hotkey}") temp_file = await download_file(s3_client, bucket, filename) if temp_file: - return SimpleNamespace(bucket=bucket, hotkey=hotkey, filename=filename, window=window, temp_file=temp_file) + return SimpleNamespace( + bucket=bucket, + hotkey=hotkey, + filename=filename, + window=window, + temp_file=temp_file, + ) return None -async def process_bucket(s3_client, bucket: str, windows: List[int], key:str = 'slice'): + +async def process_bucket( + s3_client, bucket: str, windows: List[int], key: str = "slice" +): """ Processes an S3 bucket to download files matching the given windows. @@ -485,27 +529,33 @@ async def process_bucket(s3_client, bucket: str, windows: List[int], key:str = ' """ logger.debug(f"Processing bucket {bucket} for window {windows}") files = [] - paginator = s3_client.get_paginator('list_objects_v2') + paginator = s3_client.get_paginator("list_objects_v2") for window in windows: - prefix = f'{key}-{window}' + prefix = f"{key}-{window}" logger.debug(f"Listing objects with prefix {prefix}") async for page in paginator.paginate(Bucket=bucket, Prefix=prefix): logger.trace(f"Processing page for prefix {prefix}") - if 'Contents' not in page: + if "Contents" not in page: logger.trace(f"No contents found for prefix {prefix}") continue download_tasks = [] - for obj in page.get('Contents', []): - filename = obj['Key'] + for obj in page.get("Contents", []): + filename = obj["Key"] logger.trace(f"Processing object with key {filename}") try: - parts = filename.split('-') + parts = filename.split("-") slice_window = int(parts[1]) - slice_hotkey = parts[2].split('.')[0] - logger.trace(f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}") + slice_hotkey = parts[2].split(".")[0] + logger.trace( + f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}" + ) if slice_window == window: - download_tasks.append(handle_file(s3_client, bucket, filename, slice_hotkey, slice_window)) + download_tasks.append( + handle_file( + s3_client, bucket, filename, slice_hotkey, slice_window + ) + ) except Exception: logger.exception(f"Error processing filename {filename}") continue @@ -516,7 +566,10 @@ async def process_bucket(s3_client, bucket: str, windows: List[int], key:str = ' logger.trace(f"Completed processing bucket {bucket} for windows {windows}") return files -async def download_slices_for_buckets_and_windows(buckets: List[str], windows: List[int], key:str = 'slice') -> Dict[int, List[SimpleNamespace]]: + +async def download_slices_for_buckets_and_windows( + buckets: List[str], windows: List[int], key: str = "slice" +) -> Dict[int, List[SimpleNamespace]]: """ Downloads files from multiple S3 buckets for the given windows. 
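`process_bucket` leans entirely on the `{key}-{window}-{hotkey}.pt` naming convention when it filters bucket listings. A short sketch of that parse (the hotkey below is fabricated; SS58 addresses are base58, so neither the hotkey nor the numeric window contains a `-`):

```python
filename = "slice-42-5F3xFakeHotkeyAddress.pt"  # {key}-{window}-{hotkey}.pt
parts = filename.split("-")
slice_window = int(parts[1])                    # 42
slice_hotkey = parts[2].split(".")[0]           # "5F3xFakeHotkeyAddress"
```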
@@ -530,11 +583,11 @@ async def download_slices_for_buckets_and_windows(buckets: List[str], windows: L logger.debug(f"Downloading files for buckets {set(buckets)} and windows {windows}") session = get_session() async with session.create_client( - 's3', - region_name='us-east-1', + "s3", + region_name="us-east-1", config=client_config, aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, ) as s3_client: tasks = [] for bucket in set(buckets): @@ -556,7 +609,8 @@ async def download_slices_for_buckets_and_windows(buckets: List[str], windows: L logger.debug(f"Downloaded all files grouped by windows: {windows}") return windows_dict -async def load_files_for_window(window: int, key: str = 'slice') -> List[str]: + +async def load_files_for_window(window: int, key: str = "slice") -> List[str]: """ Retrieves the paths to downloaded window files from the temporary directory. @@ -575,7 +629,8 @@ async def load_files_for_window(window: int, key: str = 'slice') -> List[str]: logger.debug(f"Found file {filename} for window {window}") return window_files -async def delete_files_before_window(window_max: int, key:str = 'slice'): + +async def delete_files_before_window(window_max: int, key: str = "slice"): """ Deletes all files on the local machine which have a window id before a specific value window_max. @@ -585,9 +640,11 @@ async def delete_files_before_window(window_max: int, key:str = 'slice'): logger.debug(f"Deleting files with window id before {window_max}") temp_dir = tempfile.gettempdir() for filename in os.listdir(temp_dir): - if filename.startswith(f"{key}-") and ( filename.endswith(".pt") or filename.endswith(".lock") ): + if filename.startswith(f"{key}-") and ( + filename.endswith(".pt") or filename.endswith(".lock") + ): try: - parts = filename.split('-') + parts = filename.split("-") window_id = int(parts[1]) if window_id < window_max: if os.path.exists(filename): @@ -596,7 +653,10 @@ async def delete_files_before_window(window_max: int, key:str = 'slice'): except Exception as e: logger.error(f"Error deleting file {filename}: {e}") -async def delete_files_from_bucket_before_window(bucket: str, window_max: int, key: str = 'slice'): + +async def delete_files_from_bucket_before_window( + bucket: str, window_max: int, key: str = "slice" +): """ Deletes all files in the specified S3 bucket which have a window id before a specific value window_max. @@ -604,28 +664,36 @@ async def delete_files_from_bucket_before_window(bucket: str, window_max: int, k bucket (str): The name of the S3 bucket. window_max (int): The maximum window id. Files with window ids less than this value will be deleted. 
""" - logger.debug(f"Deleting files in bucket {bucket} with window id before {window_max}") + logger.debug( + f"Deleting files in bucket {bucket} with window id before {window_max}" + ) session = get_session() async with session.create_client( - 's3', - region_name='us-east-1', + "s3", + region_name="us-east-1", config=client_config, aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY + aws_secret_access_key=AWS_SECRET_ACCESS_KEY, ) as s3_client: try: response = await s3_client.list_objects_v2(Bucket=bucket) - if 'Contents' in response: - for obj in response['Contents']: - filename = obj['Key'] + if "Contents" in response: + for obj in response["Contents"]: + filename = obj["Key"] if filename.startswith(f"{key}-") and filename.endswith(".pt"): try: - parts = filename.split('-') + parts = filename.split("-") window_id = int(parts[1]) if window_id < window_max: - await s3_client.delete_object(Bucket=bucket, Key=filename) - logger.debug(f"Deleted file {filename} from bucket {bucket}") + await s3_client.delete_object( + Bucket=bucket, Key=filename + ) + logger.debug( + f"Deleted file {filename} from bucket {bucket}" + ) except Exception as e: - logger.error(f"Error deleting file {filename} from bucket {bucket}: {e}") + logger.error( + f"Error deleting file {filename} from bucket {bucket}: {e}" + ) except Exception as e: logger.error(f"Error listing objects in bucket {bucket}: {e}") diff --git a/boltz/dataset.py b/boltz/dataset.py new file mode 100644 index 0000000..26cb18a --- /dev/null +++ b/boltz/dataset.py @@ -0,0 +1,483 @@ +# The MIT License (MIT) +# © 2024 Chakana.tech + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import time +import typing +import random +import requests +import asyncio +import aiohttp +import numpy as np +from tqdm import tqdm +from transformers import AutoTokenizer +from torch.utils.data import IterableDataset + +class SubsetLoader(IterableDataset): + """ + Base class for data-specific subset loader classes. 
+ + # TODO: Make this class abstract + """ + def __init__( + self, + batch_size=None, + sequence_length=None, + num_pages=None, + tokenizer: AutoTokenizer=None, + pack_samples: bool=False, + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.num_pages = num_pages + self.tokenizer = tokenizer + self.pack_samples = pack_samples + + self.num_rows_per_page = 100 + + # Buffer to hold pages loaded from the API + self.buffer = [] + + # Buffer to hold pages already loaded into a batch + self.used_buffer = [] + + # Buffer to hold padded pages + self.padded_buffer = [] + + self.lock = asyncio.Lock() # For thread-safe operations + + async def fetch_data_for_pages(self, pages): + """ + Set the pages to be used to fill the buffer. Then fetch the page data + to the buffer. + """ + + self.pages = pages + + # Empty the buffer if it is not already empty. + self.buffer = [] + + async with aiohttp.ClientSession() as session: + tasks = [self._fetch_data_for_page(page, session) for page in self.pages] + await asyncio.gather(*tasks) + + async def _fetch_data_for_page(self, page, session): + retry_limit = 10 + attempt = 0 + while attempt < retry_limit: + config_name, page_number, split = page + + # Create the request parameters + params = dict(dataset=self.name, + config=config_name, + split=split, + offset=page_number, + limit=self.num_rows_per_page + ) + + try: + async with session.get(self.rows_base_url, params=params) as response: + response.raise_for_status() + data = await response.json() + + # Prepare the data to append + buffer_to_append = [] + for row in data["rows"]: + content = row["row"]["text"] + input_ids = self.tokenizer(content, truncation=True)["input_ids"] + buffer_to_append.extend(input_ids) + buffer_to_append.append(self.tokenizer.eos_token_id) + + async with self.lock: + self.buffer.extend(buffer_to_append) + self.pages.append((config_name, page_number, split)) + break # Success, exit retry loop + + except aiohttp.ClientResponseError as e: + attempt += 1 + if attempt < retry_limit: + await asyncio.sleep(5) + else: + raise + + def _get_pad_size(self, input_ids): + """ + Get the number of tokens to be padded to the sample to match + the max allowed sequence length. + If sample packing is activated, then return 1 + """ + + if self.pack_samples: + return 1 + + sample_size = len(input_ids) + + remainder = (sample_size % self.sequence_length) + pad_size = (self.sequence_length - remainder) + + # Apply modulo again to guarantee a pad size of 0 if remainder is 0 + pad_size = pad_size % self.sequence_length + + return pad_size + + def _refill_padded_buffer(self): + """ + This method pulls one page from `self.buffer`, pads it, and pushes + it to `self.padded_buffer`. + """ + + while ( + self.buffer + and len(self.padded_buffer) < self.sequence_length + ): + + input_ids = [] + + # Search for the EOS token index and cut the buffer at it. + EOS_index = self.buffer.index(self.tokenizer.eos_token_id) + input_ids = self.buffer[:EOS_index+1] + self.buffer = self.buffer[EOS_index+1:] + + self.used_buffer += input_ids + + # Add to padded buffer without the EOS token. 
+ self.padded_buffer += input_ids[:-1] + + # Pad + self.padded_buffer += [self.tokenizer.eos_token_id] * self._get_pad_size(input_ids=input_ids[:-1]) + + def __iter__(self): + self.buffer = self.used_buffer + self.buffer + self.padded_buffer = [] + + # Pad and prepare one page for batching + self._refill_padded_buffer() + + return self + + def __next__(self): + batch = [] + + while len(self.padded_buffer) >= self.sequence_length: + batch.append(self.padded_buffer[: self.sequence_length]) + self.padded_buffer = self.padded_buffer[self.sequence_length :] + self._refill_padded_buffer() + + if len(batch) == self.batch_size: + return np.stack(batch) + + raise StopIteration + + +class DatasetLoader(SubsetLoader): + + name: str = "HuggingFaceFW/fineweb-edu-score-2" + rows_base_url: str = "https://datasets-server.huggingface.co/rows" + size_base_url: str = "https://datasets-server.huggingface.co/size" + + retry_limit: int = 10 # Number of retries + retry_delay: int = 5 # Seconds to wait between retries + num_rows_per_page: int = 100 + + @staticmethod + async def next_pages(offset: int, n_pages: int, seed: str, num_rows_per_page: int = 100): + configs_data = await DatasetLoader.fetch_dataset_configs() + rng = np.random.default_rng(hash(seed) & 0xffffffff) # Create a generator with a seed + rng.bit_generator.advance(offset) # Efficiently skip ahead `n` steps + result = [] + for _ in range(n_pages): + config = rng.choice(list(configs_data.keys())) + choice = rng.integers(0, configs_data[config]['num_rows'] - 1 - num_rows_per_page) + result.append((str(config), int(choice), configs_data[config]['split'])) + return result + + def __init__( + self, + batch_size=None, + sequence_length=None, + num_pages=None, + pages_info=None, + tokenizer: AutoTokenizer = None, + pack_samples: bool = False, + ): + super().__init__(batch_size, + sequence_length, + num_pages, + tokenizer, + pack_samples) + + # Initialize properties + self.configs_data = None + self.pages = [] + self.buffer = [] + self.lock = asyncio.Lock() # For thread-safe operations + + @classmethod + async def create( + cls, + batch_size=None, + sequence_length=None, + num_pages=None, + pages_info=None, + tokenizer: AutoTokenizer = None, + pack_samples: bool = False, + ): + self = cls( + batch_size=batch_size, + sequence_length=sequence_length, + num_pages=num_pages, + tokenizer=tokenizer, + pack_samples=pack_samples + ) + + # Fetch dataset configs asynchronously + self.configs_data = await cls.fetch_dataset_configs() + + if pages_info is not None: + await self._fetch(pages_info) + elif self.num_pages: + await self._fetch_data_to_buffer(self.num_pages) + + return self + + async def _fetch(self, page_info: typing.Tuple[str, int, str]): + self.pages = list(page_info) + num_pages = len(self.pages) + async with aiohttp.ClientSession() as session: + tasks = [self._fetch_data_for_page((config_name, page, split), session) + for (config_name, page, split) in self.pages] + await asyncio.gather(*tasks) + + async def _fetch_data_to_buffer(self, num_pages): + """ + Randomly sample pages and add their data to the buffer. + If a page is inaccessible, another one is sampled. + This method sets the `pages` property. 
+ """ + self.pages = [] + pages_to_fetch = self.get_random_pages(num_pages) + + async with aiohttp.ClientSession() as session: + tasks = [self._fetch_data_for_page(page, session) for page in pages_to_fetch] + await asyncio.gather(*tasks) + + async def fetch_data_to_rows(self, num_pages): + rows = [] + pages_to_fetch = self.get_random_pages(num_pages) + + async with aiohttp.ClientSession() as session: + tasks = [self._fetch_rows_for_page(page, session) for page in pages_to_fetch] + results = await asyncio.gather(*tasks) + for page_rows in results: + rows.extend(page_rows) + + return rows + + + async def _fetch_data_for_page(self, page, session): + """ + Fetches data asynchronously for a single page, processes it without blocking the event loop, + and appends the tokenized data to the buffer. + + Args: + page: A tuple containing the config name, page number, and split. + session: The HTTP session used for making requests. + + Raises: + Exception: If the maximum number of retry attempts is exceeded. + """ + retry_limit = self.retry_limit + attempt = 0 + while attempt < retry_limit: + config_name, page_number, split = page + + # Create the request parameters + params = { + 'dataset': self.name, + 'config': config_name, + 'split': split, + 'offset': page_number, + 'limit': self.num_rows_per_page + } + + try: + # Make an asynchronous HTTP GET request to fetch the data + async with session.get(self.rows_base_url, params=params) as response: + response.raise_for_status() # Raise an exception for HTTP errors + data = await response.json() + + # Prepare the data to append + buffer_to_append = [] + + # Asynchronously process each row without blocking the event loop + tasks = [ + self._tokenize_content(row["row"]["text"]) for row in data["rows"] + ] + + # Gather the tokenized results concurrently + row_input_ids = await asyncio.gather(*tasks) + + # Flatten the list of input IDs and append them to the buffer + for input_ids in row_input_ids: + buffer_to_append.extend(input_ids) + + # Safely append the processed data to the shared buffer + async with self.lock: + self.buffer.extend(buffer_to_append) + self.pages.append((config_name, page_number, split)) + break # Success, exit retry loop + + except aiohttp.ClientResponseError as e: + # Handle HTTP client errors with a retry mechanism + attempt += 1 + if attempt < retry_limit: + await asyncio.sleep(self.retry_delay) # Wait before retrying + else: + raise Exception(f"Maximum retry attempts exceeded for page {page}") from e + + async def _tokenize_content(self, content): + """ + Asynchronously tokenizes a string of content using the tokenizer in a separate thread. + + Args: + content: The text content to be tokenized. + + Returns: + The list of token IDs for the content, including the EOS token. 
+ """ + # Offload the CPU-bound tokenization to a thread executor to prevent blocking the event loop + input_ids = await asyncio.to_thread( + self.tokenizer.encode, content, truncation=True, max_length=self.sequence_length + ) + input_ids.append(self.tokenizer.eos_token_id) + return input_ids + + async def _fetch_rows_for_page(self, page, session): + retry_limit = self.retry_limit + attempt = 0 + while attempt < retry_limit: + config_name, page_number, split = page + + # Create the request parameters + params = dict(dataset=self.name, + config=config_name, + split=split, + offset=page_number, + limit=self.num_rows_per_page + ) + + try: + async with session.get(self.rows_base_url, params=params) as response: + response.raise_for_status() + data = await response.json() + + # Collect the rows + return [row["row"]["text"] for row in data["rows"]] + + except aiohttp.ClientResponseError as e: + attempt += 1 + if attempt < retry_limit: + await asyncio.sleep(self.retry_delay) + else: + raise + + def get_random_pages(self, num_pages): + """ + Randomly sample pages. + A page is a row number of a given split of a given dataset dump. + """ + pages = [] + + for _ in range(num_pages): + # Choose a random config + config_name = random.choice(list(self.configs_data.keys())) + + # Choose a random page (row) + page = random.randint(0, + self.configs_data[config_name]['num_rows'] - 1 - self.num_rows_per_page) + + split = self.configs_data[config_name]['split'] + + pages.append((config_name, page, split)) + + return pages + + def get_page_names(self): + """ + This is a utility function that returns the page names that were used. + Each page as a single string instead of a tuple. + """ + page_names = [] + + if hasattr(self, 'pages'): + page_names = [f'{cfg_name}_{num_rows}_{split}' for + cfg_name, num_rows, split in self.pages] + + return page_names + + @staticmethod + async def fetch_dataset_configs() -> typing.Dict[str, typing.Dict]: + """ + Fetch the different dump names, aka configs, aka samples, of the + dataset. + The returned value is a dictionary with dump names as keys and + a dict of the number of rows and the split as values. 
+ """ + # Request parameters + params = dict( + dataset=DatasetLoader.name + ) + + attempt = 0 + while attempt < DatasetLoader.retry_limit: + try: + async with aiohttp.ClientSession() as session: + async with session.get(DatasetLoader.size_base_url, params=params) as response: + response.raise_for_status() + + data = await response.json() + + # Extract the configs dict + configs_dict = data['size']['splits'] + + # Now create a dict with config names (except 'default') as + # keys, and the number of rows as values + configs_data = {entry['config']: {'num_rows': entry['num_rows'], + 'split': entry['split']} + for entry in configs_dict + if entry['config'] != 'default' + } + + return configs_data + + except aiohttp.ClientResponseError as e: + attempt += 1 + if attempt < DatasetLoader.retry_limit: + await asyncio.sleep(DatasetLoader.retry_delay) + else: + raise + + @staticmethod + async def next_pages_async(offset: int, n_pages: int, seed: str, num_rows_per_page: int = 100): + configs_data = await DatasetLoader.fetch_dataset_configs() + rng = np.random.default_rng(hash(seed) & 0xffffffff) # Create a generator with a seed + rng.bit_generator.advance(offset) # Efficiently skip ahead `n` steps + result = [] + for _ in range(n_pages): + config = rng.choice(list(configs_data.keys())) + choice = rng.integers(0, configs_data[config]['num_rows'] - 1 - num_rows_per_page) + result.append((str(config), int(choice), configs_data[config]['split'])) + return result diff --git a/boltz/hparams.py b/boltz/hparams.py new file mode 100644 index 0000000..744d2e1 --- /dev/null +++ b/boltz/hparams.py @@ -0,0 +1,87 @@ +# The MIT License (MIT) +# © 2024 Chakana.tech + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +import os +import json +import time +import requests +from types import SimpleNamespace +from transformers import AutoTokenizer, LlamaConfig + +from boltz.common import * + +# Cache file path +HPARAMS_FILE = "hparams.json" + +def create_namespace(hparams: dict) -> SimpleNamespace: + """ + Create a SimpleNamespace from the hyperparameters and add model configuration. + + Args: + hparams (dict): Hyperparameters dictionary. + + Returns: + SimpleNamespace: Namespace containing hyperparameters and model configuration. 
+ """ + hparams_ns = SimpleNamespace(**hparams) + + hparams_ns.tokenizer = AutoTokenizer.from_pretrained( + hparams_ns.tokenizer_name, verbose=False, clean_up_tokenization_spaces=True + ) + hparams_ns.tokenizer.pad_token = hparams_ns.tokenizer.eos_token + + hparams_ns.model_config = LlamaConfig( + vocab_size=hparams_ns.tokenizer.vocab_size, + hidden_size=hparams_ns.hidden_size, + num_hidden_layers=hparams_ns.num_hidden_layers, + num_attention_heads=hparams_ns.num_attention_heads, + intermediate_size=hparams_ns.intermediate_size, + num_key_value_heads=hparams_ns.num_key_value_heads, + activation_function=hparams_ns.activation_function, + max_position_embeddings=hparams_ns.max_position_embeddings, + ) + + return hparams_ns + +def load_hparams() -> SimpleNamespace: + """ + Load hyperparameters from a GitHub file, with caching and fallback mechanisms. + + Returns: + SimpleNamespace: A namespace containing the hyperparameters and model configuration. + + Example: + hparams = load_hparams() + print(hparams.hidden_size) + print(hparams.model_config) + """ + github_url = f"https://raw.githubusercontent.com/unconst/cont/master/hparams.json?timestamp={int(time.time())}" + try: + # Attempt to fetch from the GitHub file first + response = requests.get(github_url, timeout=10, headers={'Cache-Control': 'no-cache'}) + response.raise_for_status() + hparams = json.loads(response.text) + logger.debug("Successfully loaded parameters from GitHub.") + except (requests.RequestException, json.JSONDecodeError) as e: + logger.debug(f"Error loading parameters from GitHub: {e}") + logger.debug("Attempting to load from cache...") + with open(HPARAMS_FILE, "r") as f: + hparams = json.load(f) + # Cache the new parameters + with open(HPARAMS_FILE, "w") as f: + json.dump(hparams, f, indent=4) + return create_namespace(hparams) diff --git a/scripts/run.sh b/scripts/run.sh new file mode 100755 index 0000000..9db4a4d --- /dev/null +++ b/scripts/run.sh @@ -0,0 +1,555 @@ +#!/usr/bin/env bash + +# The MIT License (MIT) +# © 2024 Chakana.tech + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. 
+ +set -euo pipefail + +# Initialize default values +DEBUG=false +PROJECT="aesop" +AWS_ACCESS_KEY_ID="" +AWS_SECRET_ACCESS_KEY="" +BUCKET="" + +# Function to display help message +display_help() { + cat << EOF +Usage: $0 [options] + +Options: + --debug Enable debug mode + --project Set the project name (default: aesop) + --aws-access-key-id Set AWS Access Key ID + --aws-secret-access-key Set AWS Secret Access Key + --bucket Set the S3 bucket name + -h, --help Display this help message + +Description: + Installs and runs a Boltzmann miner on your GPU. +EOF +} + +# Parse command-line arguments +while [[ $# -gt 0 ]]; do + key="$1" + case $key in + --debug) + DEBUG=true + shift + ;; + --project) + PROJECT="$2" + shift 2 + ;; + --aws-access-key-id) + AWS_ACCESS_KEY_ID="$2" + shift 2 + ;; + --aws-secret-access-key) + AWS_SECRET_ACCESS_KEY="$2" + shift 2 + ;; + --bucket) + BUCKET="$2" + shift 2 + ;; + -h|--help|-help|--h) + display_help + exit 0 + ;; + *) + echo "Unknown option: $1" + display_help + exit 1 + ;; + esac +done + +# Set up colors and styles +if [[ -t 1 ]]; then + tty_escape() { printf "\033[%sm" "$1"; } +else + tty_escape() { :; } +fi +tty_mkbold() { tty_escape "1;$1"; } +tty_blue="$(tty_mkbold 34)" +tty_red="$(tty_mkbold 31)" +tty_green="$(tty_mkbold 32)" +tty_yellow="$(tty_mkbold 33)" +tty_bold="$(tty_mkbold 39)" +tty_reset="$(tty_escape 0)" + +# Logging functions +ohai() { + printf "${tty_blue}==>${tty_bold} %s${tty_reset}\n" "$*" +} + +pdone() { + printf " ${tty_green}[✔]${tty_bold} %s${tty_reset}\n" "$*" +} + +info() { + printf "${tty_green}%s${tty_reset}\n" "$*" +} + +warn() { + printf "${tty_yellow}Warning${tty_reset}: %s\n" "$*" >&2 +} + +error() { + printf "${tty_red}Error${tty_reset}: %s\n" "$*" >&2 +} + +abort() { + error "$@" + exit 1 +} + +trap 'abort "An unexpected error occurred."' ERR + +getc() { + local save_state + save_state="$(/bin/stty -g)" + /bin/stty raw -echo + IFS='' read -r -n 1 -d '' "$@" + /bin/stty "${save_state}" +} + +wait_for_user() { + local c + echo + echo "Press ${tty_bold}RETURN${tty_reset}/${tty_bold}ENTER${tty_reset} to continue or any other key to abort:" + getc c + # we test for \r and \n because some stuff does \r instead + if ! [[ "${c}" == $'\r' || "${c}" == $'\n' ]] + then + exit 1 + fi +} + +execute() { + ohai "Running: $*" + if ! "$@"; then + abort "Failed during: $*" + fi +} + +have_sudo_access() { + if ! command -v sudo &> /dev/null; then + warn "sudo command not found. Please install sudo or run as root." + return 1 + fi + if [ "$EUID" -ne 0 ]; then + if ! sudo -n true 2>/dev/null; then + warn "This script requires sudo access to install packages. Please run as root or ensure your user has sudo privileges." + return 1 + fi + fi + return 0 +} + +execute_sudo() { + if have_sudo_access; then + ohai "sudo $*" + if ! sudo "$@"; then + abort "Failed to execute: sudo $*" + fi + else + warn "Sudo access is required, attempting to run without sudo" + ohai "$*" + if ! 
"$@"; then + abort "Failed to execute: $*" + fi + fi +} + +# Function to set or replace environment variables in bash_profile +set_or_replace_env_var() { + local var_name="$1" + local var_value="$2" + local profile_file="$3" + + # Escape special characters for sed + local escaped_var_value=$(printf '%s\n' "$var_value" | sed -e 's/[\/&]/\\&/g') + + if grep -q "^export $var_name=" "$profile_file"; then + # Variable exists, replace it + sed -i.bak "s/^export $var_name=.*/export $var_name=\"$escaped_var_value\"/" "$profile_file" + else + # Variable does not exist, append it + echo "export $var_name=\"$var_value\"" >> "$profile_file" + fi +} + +# Clear the screen and display the logo +clear +echo "" +echo "" +echo " ______ _____ _______ ______ _______ _______ __ _ __ _" +echo " |_____] | | | | ____/ | | | |_____| | \ | | \ |" +echo " |_____] |_____| |_____ | /_____ | | | | | | \_| | \_|" +echo " " +echo "" +echo "" + +echo "This script will do the following:" +echo "1. Install required software (Git, npm, pm2, Python 3.12)" +echo "2. Set up AWS credentials" +echo "3. Clone and set up the Boltzmann repository" +echo "4. Create and register Bittensor wallets" +echo "5. Configure wandb for logging" +echo "6. Clean the specified S3 bucket" +echo "7. Start Boltzmann miners on available GPUs" +echo "" +echo "Please ensure you have a stable internet connection and sufficient permissions to install software." +echo "" + +wait_for_user + +# Ensure ~/.bash_profile exists +touch ~/.bash_profile +source ~/.bash_profile + +# Backup the bash_profile +cp ~/.bash_profile ~/.bash_profile.bak + +# Prompt the user for AWS credentials if not supplied via command-line +ohai "Getting AWS credentials ..." +if [[ -z "$AWS_ACCESS_KEY_ID" ]] || [[ -z "$AWS_SECRET_ACCESS_KEY" ]] || [[ -z "$BUCKET" ]]; then + # TODO: Consider securely storing AWS credentials rather than storing them in plain text + warn "This script will store your AWS credentials in your ~/.bash_profile file." + warn "This is not secure and is not recommended." + read -p "Do you want to proceed? [y/N]: " proceed + if [[ "$proceed" != "y" && "$proceed" != "Y" ]]; then + abort "Aborted by user." + fi + + if [[ -z "$AWS_ACCESS_KEY_ID" ]]; then + read -p "Enter your AWS Access Key ID: " AWS_ACCESS_KEY_ID + fi + if [[ -z "$AWS_SECRET_ACCESS_KEY" ]]; then + read -p "Enter your AWS Secret Access Key: " AWS_SECRET_ACCESS_KEY + fi + if [[ -z "$BUCKET" ]]; then + read -p "Enter your S3 Bucket Name: " BUCKET + fi +fi + +# Overwrite or add the AWS credentials in the bash_profile +set_or_replace_env_var "AWS_ACCESS_KEY_ID" "$AWS_ACCESS_KEY_ID" ~/.bash_profile +set_or_replace_env_var "AWS_SECRET_ACCESS_KEY" "$AWS_SECRET_ACCESS_KEY" ~/.bash_profile +set_or_replace_env_var "BUCKET" "$BUCKET" ~/.bash_profile + +# Source the bash_profile to apply the changes +source ~/.bash_profile +pdone "AWS credentials set in ~/.bash_profile" + +ohai "Installing requirements ..." +# Install Git if not present +if ! command -v git &> /dev/null; then + ohai "Git not found. Installing git ..." + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + ohai "Detected Linux" + if [ -f /etc/os-release ]; then + . /etc/os-release + if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then + ohai "Detected Ubuntu, installing Git..." 
+ if [[ "$DEBUG" == "true" ]]; then + execute_sudo apt-get update -y + execute_sudo apt-get install git -y + else + execute_sudo apt-get update -y > /dev/null 2>&1 + execute_sudo apt-get install git -y > /dev/null 2>&1 + fi + else + warn "Unsupported Linux distribution: $ID" + abort "Cannot install Git automatically" + fi + else + warn "Cannot detect Linux distribution" + abort "Cannot install Git automatically" + fi + else + abort "Unsupported OS type: $OSTYPE" + fi +else + pdone "Git is already installed" +fi + +# TODO: Add error handling for package installations +# TODO: Ensure compatibility with different package managers + +# Check for Rust installation +if ! command -v rustc &> /dev/null; then + ohai "Installing Rust ..." + if [[ "$DEBUG" == "true" ]]; then + execute curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + else + execute curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y > /dev/null 2>&1 + fi + # Add Rust to the PATH for the current session + source $HOME/.cargo/env +fi +pdone "Rust is installed" + +# Install uv if not present +if ! command -v uv &> /dev/null; then + ohai "Installing uv ..." + if [[ "$DEBUG" == "true" ]]; then + execute curl -LsSf https://astral.sh/uv/install.sh | sh + else + execute curl -LsSf https://astral.sh/uv/install.sh | sh > /dev/null 2>&1 + fi + # Add uv to the PATH for the current session + export PATH="$HOME/.cargo/bin:$PATH" +fi +pdone "uv is installed" + +# Check if npm is installed +if ! command -v npm &> /dev/null; then + ohai "Installing npm ..." + if ! command -v node &> /dev/null; then + ohai "Node.js could not be found, installing..." + if ! curl -fsSL https://deb.nodesource.com/setup_18.x | bash; then + abort "Failed to download Node.js setup script" + fi + if ! execute_sudo apt-get install -y nodejs; then + abort "Failed to install Node.js" + fi + fi + if ! curl -L https://www.npmjs.com/install.sh | sh; then + abort "Failed to install npm" + fi +fi +pdone "npm is installed" + +# Install pm2 +if ! command -v pm2 &> /dev/null; then + ohai "Installing pm2 ..." + if [[ "$DEBUG" == "true" ]]; then + execute npm install pm2 -g + else + execute npm install pm2 -g > /dev/null 2>&1 + fi +fi +pdone "pm2 is installed" + +ohai "Installing Boltzmann ..." +# Check if we are inside the boltzmann repository +if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then + REPO_PATH="." +else + if [ ! -d "boltzmann" ]; then + ohai "Cloning boltzmann ..." + execute git clone https://github.com/unconst/boltzmann + REPO_PATH="boltzmann/" + else + REPO_PATH="boltzmann/" + fi +fi +pdone "Boltzmann repository is ready at $REPO_PATH" + +# Install Python 3.12 if not installed +if ! command -v python3.12 &> /dev/null; then + ohai "Installing python3.12 ..." + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + ohai "Detected Linux" + if [ -f /etc/os-release ]; then + . /etc/os-release + if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then + ohai "Detected Ubuntu, installing Python 3.12..." 
+ if [[ "$DEBUG" == "true" ]]; then + if have_sudo_access; then + execute_sudo add-apt-repository ppa:deadsnakes/ppa -y + else + warn "Skipping add-apt-repository due to lack of sudo access" + fi + execute_sudo apt-get update -y + else + if have_sudo_access; then + execute_sudo add-apt-repository ppa:deadsnakes/ppa -y > /dev/null 2>&1 + else + warn "Skipping add-apt-repository due to lack of sudo access" + fi + execute_sudo apt-get update -y > /dev/null 2>&1 + execute_sudo apt-get install --reinstall python3-apt > /dev/null 2>&1 + execute_sudo apt-get install python3.12 -y > /dev/null 2>&1 + execute_sudo apt-get install python3.12-venv > /dev/null 2>&1 + fi + else + warn "Unsupported Linux distribution: $ID" + abort "Cannot install Python 3.12 automatically" + fi + else + warn "Cannot detect Linux distribution" + abort "Cannot install Python 3.12 automatically" + fi + else + abort "Unsupported OS type: $OSTYPE" + fi +fi +pdone "Python 3.12 is installed" + +# Create a virtual environment if it does not exist +if [ ! -d "$REPO_PATH/venv" ]; then + ohai "Creating virtual environment at $REPO_PATH..." + if [[ "$DEBUG" == "true" ]]; then + execute uv venv "$REPO_PATH/.venv" + else + execute uv venv "$REPO_PATH/.venv" > /dev/null 2>&1 + fi +fi +pdone "Virtual environment is set up at $REPO_PATH" + + +# Activate the virtual environment +ohai "Activating virtual environment ..." +source $REPO_PATH/.venv/bin/activate +pdone "Virtual environment activated" + +ohai "Installing Python requirements ..." +if [[ "$DEBUG" == "true" ]]; then + execute uv pip install -r $REPO_PATH/requirements.txt + execute uv pip install --upgrade cryptography pyOpenSSL +else + execute uv pip install -r $REPO_PATH/requirements.txt > /dev/null 2>&1 + execute uv pip install --upgrade cryptography pyOpenSSL > /dev/null 2>&1 +fi +pdone "Python requirements installed" + +# Check for GPUs +ohai "Checking for GPUs..." +if ! command -v nvidia-smi &> /dev/null; then + warn "nvidia-smi command not found. Please ensure NVIDIA drivers are installed." + NUM_GPUS=0 +else + NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) + + if [ "$NUM_GPUS" -gt 0 ]; then + nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | while read -r memory; do + pdone "Found GPU with $((memory / 1024)) GB of memory" + done + else + warn "No GPUs found on this machine." + fi +fi + +# Check system RAM +if command -v free &> /dev/null; then + TOTAL_RAM=$(free -g | awk '/^Mem:/{print $2}') + pdone "System RAM: ${TOTAL_RAM} GB" +else + warn "Cannot determine system RAM. 'free' command not found." +fi + +ohai "Creating wallets ..." +# Create the default key +if ! python3 -c "import bittensor as bt; w = bt.wallet(); print(w.coldkey_file.exists_on_device())" | grep -q "True"; then + execute btcli w new_coldkey --wallet.path ~/.bittensor/wallets --wallet.name default --n-words 12 +fi +pdone "Wallet 'default' is ready" + +# Ensure btcli is installed +if ! command -v btcli &> /dev/null; then + abort "btcli command not found. Please ensure it is installed." 
+fi + +# Create hotkeys and register them +if [ "$NUM_GPUS" -gt 0 ]; then + for i in $(seq 0 $((NUM_GPUS - 1))); do + # Check if the hotkey file exists on the device + exists_on_device=$(python3 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); print(w.hotkey_file.exists_on_device())" 2>/dev/null) + if [ "$exists_on_device" != "True" ]; then + echo "n" | btcli wallet new_hotkey --wallet.name default --wallet.hotkey C$i --n-words 12 > /dev/null 2>&1; + fi + pdone "Created Hotkey 'C$i'" + + # Check if the hotkey is registered on subnet 220 + is_registered=$(python3 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); sub = bt.subtensor('test'); print(sub.is_hotkey_registered_on_subnet(hotkey_ss58=w.hotkey.ss58_address, netuid=220))" 2>/dev/null) + if [[ "$is_registered" != *"True"* ]]; then + ohai "Registering hotkey 'C$i' on subnet 220" + btcli subnet pow_register --wallet.name default --wallet.hotkey C$i --netuid 220 --subtensor.network test --no_prompt > /dev/null 2>&1; + fi + pdone "Registered Hotkey 'C$i' on subnet 220" + done +else + warn "No GPUs found. Skipping hotkey creation." + exit +fi +pdone "All hotkeys registered" + +ohai "Logging into wandb..." +execute wandb login +pdone "wandb is configured" + +# Clean the bucket +ohai "Cleaning bucket $BUCKET..." +if [[ "$DEBUG" == "true" ]]; then + execute python3 $REPO_PATH/tools/clean.py --bucket "$BUCKET" +else + execute python3 $REPO_PATH/tools/clean.py --bucket "$BUCKET" > /dev/null 2>&1 +fi +pdone "Bucket '$BUCKET' cleaned" + +# Close down all previous processes and restart them +if pm2 list | grep -q 'online'; then + ohai "Stopping old pm2 processes..." + pm2 delete all + pdone "Old processes stopped" +fi + +# Start all the processes again +if [ "$NUM_GPUS" -gt 0 ]; then + for i in $(seq 0 $((NUM_GPUS - 1))); do + # Adjust GPU index for zero-based numbering + GPU_INDEX=$i + GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | sed -n "$((i + 1))p") + if [ -z "$GPU_MEMORY" ]; then + warn "Could not get GPU memory for GPU $i" + continue + fi + # Determine batch size based on GPU memory + if [ "$GPU_MEMORY" -ge 80000 ]; then + BATCH_SIZE=6 + elif [ "$GPU_MEMORY" -ge 40000 ]; then + BATCH_SIZE=3 + elif [ "$GPU_MEMORY" -ge 20000 ]; then + BATCH_SIZE=1 + else + BATCH_SIZE=1 + fi + ohai "Starting miner on GPU $GPU_INDEX with batch size $BATCH_SIZE..." + if [[ "$DEBUG" == "true" ]]; then + execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" + else + execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" > /dev/null 2>&1 + fi + done +else + warn "No GPUs found. Skipping miner startup." 
+fi +pdone "All miners started" +pm2 list + +echo "" +pdone "SUCCESS" +echo "" + +# Start logging process 1 +pm2 logs C0 + diff --git a/scripts/start.sh b/scripts/start.sh new file mode 100755 index 0000000..cc621fc --- /dev/null +++ b/scripts/start.sh @@ -0,0 +1,36 @@ +# The MIT License (MIT) +# © 2024 Chakana.tech + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# Close down all previous processes and restart them. +pm2 sendSignal SIGINT all +pm2 delete all +# Delete items from bucket +BUCKET=${1:-decis} +PROJECT=${2:-aesop} +python3 tools/clean.py --bucket $BUCKET + +# Start all the processes again. +pm2 start validator.py --interpreter python3 --name V1 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey default --bucket $BUCKET --device cuda:0 --use_wandb --project $PROJECT +pm2 start miner.py --interpreter python3 --name M1 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M1 --bucket $BUCKET --device cuda:1 --use_wandb --project $PROJECT +pm2 start miner.py --interpreter python3 --name M2 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M2 --bucket $BUCKET --device cuda:2 --use_wandb --project $PROJECT +pm2 start miner.py --interpreter python3 --name M3 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M3 --bucket $BUCKET --device cuda:3 --use_wandb --project $PROJECT +pm2 start miner.py --interpreter python3 --name M4 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M4 --bucket $BUCKET --device cuda:5 --use_wandb --random --project $PROJECT +pm2 start miner.py --interpreter python3 --name M5 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M5 --bucket $BUCKET --device cuda:6 --use_wandb --random --project $PROJECT +pm2 start miner.py --interpreter python3 --name M6 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M3 --bucket $BUCKET --device cuda:4 --use_wandb --baseline --project $PROJECT + + + diff --git a/scripts/start_distributed.sh b/scripts/start_distributed.sh new file mode 100644 index 0000000..21e39b6 --- /dev/null +++ b/scripts/start_distributed.sh @@ -0,0 +1,43 @@ +# The MIT License (MIT) +# © 2024 Chakana.tech + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the “Software”), to deal in the Software without restriction, including without limitation +# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, +# and to permit persons to whom the Software is 
furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of +# the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO +# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + +# Close down all previous processes and restart them. + +pm2 sendSignal SIGINT all +pm2 delete all +# Delete items from bucket +BUCKET=${1:-cont2} +PROJECT=${2:-aesop} +# python3 tools/clean.py --bucket $BUCKET + +# Number of GPUs to use. +NGPU=${NGPU:-3} +# The master port for distributed training. +MASTER_PORT=${MASTER_PORT:-29500} +# Which rank should be logged +LOG_RANK=${LOG_RANK:-0} +# Uncomment for debugging. +# export TORCHELASTIC_ERROR_FILE=error.log +# export TORCHELASTIC_DEBUG=1 +# export PYTHONFAULTHANDLER=1 + +# Start the miner using torchrun with distributed training. +pm2 start "torchrun --nproc_per_node=${NGPU} \ + --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank \ + --tee 3 miner.py -- --actual_batch_size 6 --wallet.name Bistro \ + --wallet.hotkey M111 --bucket $BUCKET --device cuda:1 --use_wandb --project $PROJECT --debug" --name M1 --interpreter none From 4c5db4d837e73df56a03af9e13f0de11e96b9313 Mon Sep 17 00:00:00 2001 From: distributedstatemachine Date: Mon, 28 Oct 2024 02:30:11 +0400 Subject: [PATCH 5/9] chore: refactor --- README.md | 2 +- common.py | 502 ---------------------------------- dataset.py | 483 --------------------------------- docker_run.sh | 387 -------------------------- hparams.py | 87 ------ miner.py | 4 +- run.sh | 555 -------------------------------------- start.sh | 36 --- tests/eval.py | 2 +- tests/legacy_miner.py | 6 +- tests/legacy_validator.py | 6 +- validator.py | 6 +- 12 files changed, 13 insertions(+), 2063 deletions(-) delete mode 100644 common.py delete mode 100644 dataset.py delete mode 100644 docker_run.sh delete mode 100644 hparams.py delete mode 100755 run.sh delete mode 100755 start.sh diff --git a/README.md b/README.md index afcbd51..642fbf2 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ --- ```bash -/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/unconst/boltzmann/master/run.sh)" +/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/unconst/boltzmann/master/scripts/run.sh)" ``` --- diff --git a/common.py b/common.py deleted file mode 100644 index 5f12be6..0000000 --- a/common.py +++ /dev/null @@ -1,502 +0,0 @@ -# The MIT License (MIT) -# © 2024 Chakana.tech - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. 
- -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -import os -import io -import sys -import uuid -import time -import fcntl -import torch -import uvloop -import hashlib -import asyncio -import logging -import tempfile -import aiofiles -import numpy as np -import aiobotocore -import bittensor as bt -import botocore.config -from typing import List, Dict -from dotenv import dotenv_values -from types import SimpleNamespace -from rich.logging import RichHandler -from filelock import FileLock, Timeout -from aiobotocore.session import get_session -from rich.highlighter import NullHighlighter - -# Configure loguru logger -FORMAT = "%(message)s" -logging.basicConfig( - level=logging.INFO, - format=FORMAT, - datefmt="[%X]", - handlers=[ - RichHandler( - markup=True, - rich_tracebacks=True, - highlighter=NullHighlighter(), - show_level=False, - show_time=False, - show_path=False - ) - ] -) -logger = logging.getLogger("rich") -logger.setLevel(logging.INFO) -def debug(): - logger.setLevel(logging.DEBUG) -def trace(): - logger.setLevel(logging.TRACE) -# Log helper. -def T(): return time.time() -def P( w, d ): return f"[steel_blue]{w}[/steel_blue] ([grey63]{d:.2f}s[/grey63])" - -# Load environment variables -env_config = {**dotenv_values(".env"), **os.environ} -AWS_ACCESS_KEY_ID = env_config.get('AWS_ACCESS_KEY_ID') -AWS_SECRET_ACCESS_KEY = env_config.get('AWS_SECRET_ACCESS_KEY') - -# Configure the S3 client -client_config = botocore.config.Config( - max_pool_connections=256, -) - -# Set uvloop as the event loop policy -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -# Define a semaphore to limit concurrent downloads (adjust as needed) -semaphore = asyncio.Semaphore(1000) - -async def get_slices( filename:str, device:str ) -> Dict[str, torch.Tensor]: - # Attempt to acquire the lock with a timeout of 1 second. - lock: FileLock = FileLock(f"{filename}.lock") - with lock.acquire(timeout=5): - pass - return torch.load( - filename, - map_location=torch.device(device), - weights_only = True, - ) - -async def apply_slices_to_model(model: torch.nn.Module, window: int, seed: str, compression: int, key:str = 'slice') -> List[str]: - """ - Applies slices from a specific window to the given model. - - Args: - model (torch.nn.Module): The PyTorch model to which the slices will be applied. - window (int): The window identifier. - seed (str): The seed used for generating indices. - compression (int): The compression factor. - - Returns: - List[str]: A list of all the slice files that were applied. - """ - # First get the indices associated with the window given the model. - indices_dict = await get_indices_for_window(model, seed, compression) - - # Load all the slices associated with this window. - slice_files = await load_files_for_window(window=window, key = key) - - # Dictionary to keep track of the number of slices applied per parameter. - slices_per_param = {name: 0 for name, _ in model.named_parameters()} - - # Dictionary to accumulate the sum of values for each parameter. 
-    param_sums = {name: torch.zeros_like(param.data) for name, param in model.named_parameters()}
-
-    # Iterate over each slice file and compute the sum of values.
-    for file_i in slice_files:
-        # get_slices acquires a file lock internally, ensuring exclusive access to the slice file.
-        try:
-            slice_i = await get_slices(file_i, model.device)
-            for name, param in model.named_parameters():
-                if name not in indices_dict or name not in slice_i:
-                    continue
-                values = slice_i[name].to(param.data.device)
-                param_indices = indices_dict[name].to(param.data.device)
-                param_sums[name].view(-1)[param_indices] += values
-                slices_per_param[name] += 1
-                del values
-            del slice_i
-        except Timeout:
-            # The lock could not be acquired within the timeout.
-            logger.error(f"Timeout occurred while trying to acquire lock on {file_i}")
-            continue
-        except Exception as e:
-            logger.exception(f"Error applying slice from {file_i}: {e}")
-
-    # Apply the average to the parameters.
-    for name, param in model.named_parameters():
-        if name not in slices_per_param or name not in indices_dict or slices_per_param[name] == 0:
-            continue
-        param_indices = indices_dict[name].to(param.data.device)
-        avg_param = param_sums[name].view(-1)[param_indices] / slices_per_param[name]
-        avg_param = avg_param.to(param.data.dtype)
-        avg_param = avg_param.to(param.data.device)
-        param.data.view(-1)[param_indices] = avg_param.clone()
-
-    # Return the list of the files applied.
-    return slice_files
-
-async def upload_slice_for_window(bucket: str, model: torch.nn.Module, window: int, seed: str, wallet: 'bt.wallet', compression: int, key: str = 'slice'):
-    """
-    Uploads a compressed slice of a PyTorch model to an S3 bucket.
-
-    Args:
-        bucket (str): Name of the S3 bucket.
-        model (torch.nn.Module): The PyTorch model to be sliced and uploaded.
-        window (int): The window identifier.
-        seed (str): The seed used for generating indices.
-        wallet (bt.wallet): The wallet object containing the hotkey.
-        compression (int): The compression factor.
-        key (str): Prefix used to name the uploaded object.
-    """
-    filename = f'{key}-{window}-{wallet.hotkey.ss58_address}.pt'
-    logger.debug(f"Uploading slice to S3: {filename}")
-
-    model_state_dict = model.state_dict()
-    indices = await get_indices_for_window(model, seed, compression)
-
-    # Apply the slice to the model parameters
-    for name, param in model.named_parameters():
-        model_state_dict[name] = param.data.view(-1)[indices[name].to(model.device)].cpu()
-
-    # Create a temporary file and write the sliced model state dictionary to it
-    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-        torch.save(model_state_dict, temp_file)
-        temp_file_name = temp_file.name  # Store the temporary file name
-
-    # Upload the file to S3
-    session = get_session()
-    async with session.create_client(
-        's3',
-        region_name='us-east-1',
-        config=client_config,
-        aws_access_key_id=AWS_ACCESS_KEY_ID,
-        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
-    ) as s3_client:
-        try:
-            with open(temp_file_name, 'rb') as f:
-                await s3_client.put_object(Bucket=bucket, Key=filename, Body=f)
-            # Set the object ACL to public-read
-            await s3_client.put_object_acl(
-                Bucket=bucket,
-                Key=filename,
-                ACL='public-read'
-            )
-            logger.debug(f"Successfully uploaded slice to S3: {filename}")
-        except Exception:
-            logger.exception(f"Failed to upload slice {filename} to S3")
-        finally:
-            # Clean up the temporary file
-            os.remove(temp_file_name)
-            logger.debug(f"Temporary file {temp_file_name} removed")
-
-async def upload_master(bucket: str, model: torch.nn.Module, wallet: 'bt.wallet'):
-    """
-    Uploads the master PyTorch model to an S3 bucket.
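A quick way to see what the averaging loop in `apply_slices_to_model` above does is to run it on a toy flat parameter. This is an editor's sketch, not code from the patch; the sizes, indices, and peer values are made up:

```python
import torch

# One flat "parameter" of 10 elements; two peers contribute values
# at the same compressed indices for this window.
param = torch.zeros(10)
indices = torch.tensor([1, 4, 7])
slices = [torch.tensor([1.0, 2.0, 3.0]),   # values from peer A
          torch.tensor([3.0, 4.0, 5.0])]   # values from peer B

sums = torch.zeros_like(param)
count = 0
for values in slices:
    sums.view(-1)[indices] += values       # accumulate, as in param_sums
    count += 1                             # as in slices_per_param

param.view(-1)[indices] = sums.view(-1)[indices] / count
print(param)  # positions 1, 4, 7 now hold the means 2.0, 3.0, 4.0
```

Untouched positions keep their previous values; only the window's indices are overwritten with the peer average.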
- - Args: - bucket (str): Name of the S3 bucket. - model (torch.nn.Module): The PyTorch model to be uploaded. - wallet (bt.wallet): The wallet object containing the hotkey. - """ - upload_filename = f'master-{wallet.hotkey.ss58_address}.pt' - logger.debug(f"Uploading master model to S3: {upload_filename}") - - session = get_session() - async with session.create_client( - 's3', - region_name='us-east-1', - config=client_config, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY - ) as s3_client: - try: - # Create a temporary file and write the model state dictionary to it - with tempfile.NamedTemporaryFile(delete=False) as temp_file: - torch.save(model.state_dict(), temp_file) - temp_file_name = temp_file.name - - # Upload the file to S3 - with open(temp_file_name, 'rb') as f: - await s3_client.put_object(Bucket=bucket, Key=upload_filename, Body=f) - # Set the object ACL to public-read - await s3_client.put_object_acl( - Bucket=bucket, - Key=upload_filename, - ACL='public-read' - ) - logger.debug(f"Successfully uploaded master model to S3: {upload_filename}") - except Exception: - logger.exception(f"Failed to upload master model {upload_filename} to S3") - finally: - # Clean up the temporary file - os.remove(temp_file_name) - logger.debug(f"Temporary file {temp_file_name} removed") - -async def get_indices_for_window(model: torch.nn.Module, seed: str, compression: int) -> Dict[str, torch.LongTensor]: - """ - Computes the indices for the given window and compression factor. - - Args: - model (torch.nn.Module): The PyTorch model. - seed (str): The window seed identifier. - compression (int): The compression factor. - - Returns: - Dict[str, torch.LongTensor]: A dictionary mapping parameter names to index tensors. - """ - logger.debug(f"Computing indices for window seed {seed} with compression {compression}") - result = {} - # Seed the random number generator with the seed - seed = int(hashlib.md5(str(seed).encode('utf-8')).hexdigest(), 16) % (2**32) - rng = np.random.default_rng(seed) - for name, param in model.named_parameters(): - # Randomly select indices based on the compression factor - num_indices = max(1, int(param.numel() // compression)) - indices = rng.choice(param.numel(), size=num_indices, replace=False) - result[name] = torch.from_numpy(indices).long().cpu() - return result - -async def download_file(s3_client, bucket: str, filename: str) -> str: - """ - Downloads a file from S3, using parallel downloads for large files. - - Args: - s3_client: The S3 client. - bucket (str): Name of the S3 bucket. - filename (str): The S3 object key (filename). - - Returns: - str: The path to the downloaded file in the temporary directory. - """ - async with semaphore: - temp_file = os.path.join(tempfile.gettempdir(), filename) - # Check if the file exists. 
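Since `get_indices_for_window` derives everything from the seed string, any two parties can recompute identical indices without communicating. A minimal sketch of that derivation under assumed sizes (a 1000-element parameter, compression factor 100):

```python
import hashlib
import numpy as np

def window_indices(numel: int, seed: str, compression: int) -> np.ndarray:
    # Same derivation as get_indices_for_window: fold an MD5 of the seed
    # into 32 bits, seed numpy's generator, then sample numel // compression
    # distinct flat indices.
    folded = int(hashlib.md5(str(seed).encode("utf-8")).hexdigest(), 16) % (2**32)
    rng = np.random.default_rng(folded)
    num_indices = max(1, numel // compression)
    return rng.choice(numel, size=num_indices, replace=False)

a = window_indices(1000, "window-42", compression=100)
b = window_indices(1000, "window-42", compression=100)
assert (a == b).all()  # both "peers" select the same 10 indices
```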
-        if os.path.exists(temp_file):
-            logger.debug(f"File {temp_file} already exists, skipping download.")
-            return temp_file
-        lock_file = f"{temp_file}.lock"
-        lock = FileLock(lock_file)
-        try:
-            # Try to acquire the file lock with a timeout
-            with lock.acquire(timeout=1):
-                # Proceed to download the file
-                logger.debug(f"Downloading file {filename} to {temp_file}")
-                head_response = await s3_client.head_object(Bucket=bucket, Key=filename)
-                object_size = head_response['ContentLength']  # Currently unused; the download streams in fixed-size chunks.
-                CHUNK_SIZE = 1 * 1024 * 1024  # 1 MB
-
-                response = await s3_client.get_object(Bucket=bucket, Key=filename)
-                async with aiofiles.open(temp_file, 'wb') as outfile:
-                    while True:
-                        chunk = await response['Body'].read(CHUNK_SIZE)
-                        if not chunk:
-                            break
-                        await outfile.write(chunk)
-
-                logger.debug(f"Successfully downloaded file {filename} to {temp_file}")
-                return temp_file
-
-        except Timeout:
-            logger.error(f"Timeout occurred while trying to acquire lock on {lock_file}")
-            return None
-        except Exception as e:
-            logger.exception(f"Failed to download file {filename} from bucket {bucket}: {e}")
-            return None
-        finally:
-            # The lock is automatically released when exiting the 'with' block
-            pass
-
-async def handle_file(s3_client, bucket: str, filename: str, hotkey: str, window: int):
-    """
-    Handles downloading a single file from S3.
-
-    Args:
-        s3_client: The S3 client.
-        bucket (str): Name of the S3 bucket.
-        filename (str): The S3 object key (filename).
-        hotkey (str): The hotkey identifier.
-        window (int): The window identifier.
-
-    Returns:
-        SimpleNamespace: An object containing file metadata and the path to the downloaded file.
-    """
-    logger.debug(f"Handling file {filename} for window {window} and hotkey {hotkey}")
-    temp_file = await download_file(s3_client, bucket, filename)
-    if temp_file:
-        return SimpleNamespace(bucket=bucket, hotkey=hotkey, filename=filename, window=window, temp_file=temp_file)
-    return None
-
-async def process_bucket(s3_client, bucket: str, windows: List[int], key: str = 'slice'):
-    """
-    Processes an S3 bucket to download files matching the given windows.
-
-    Args:
-        s3_client: The S3 client.
-        bucket (str): Name of the S3 bucket.
-        windows (List[int]): A list of window identifiers.
-
-    Returns:
-        List[SimpleNamespace]: A list of file metadata and paths for downloaded files.
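`download_file` runs its whole body under the module-level `asyncio.Semaphore(1000)`, which is what caps the number of in-flight downloads. A stripped-down sketch of that pattern, with a cap of 3 and a dummy `fetch` standing in for the real S3 call:

```python
import asyncio

semaphore = asyncio.Semaphore(3)  # the module above uses 1000

async def fetch(i: int) -> int:
    # The semaphore is held for the whole body, so at most 3 of these
    # run concurrently no matter how many are scheduled.
    async with semaphore:
        await asyncio.sleep(0.1)
        return i

async def main():
    print(await asyncio.gather(*(fetch(i) for i in range(10))))

asyncio.run(main())
```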
- """ - logger.debug(f"Processing bucket {bucket} for window {windows}") - files = [] - paginator = s3_client.get_paginator('list_objects_v2') - - for window in windows: - prefix = f'{key}-{window}' - logger.debug(f"Listing objects with prefix {prefix}") - async for page in paginator.paginate(Bucket=bucket, Prefix=prefix): - logger.trace(f"Processing page for prefix {prefix}") - if 'Contents' not in page: - logger.trace(f"No contents found for prefix {prefix}") - continue - download_tasks = [] - for obj in page.get('Contents', []): - filename = obj['Key'] - logger.trace(f"Processing object with key {filename}") - try: - parts = filename.split('-') - slice_window = int(parts[1]) - slice_hotkey = parts[2].split('.')[0] - logger.trace(f"Parsed filename {filename} into window {slice_window} and hotkey {slice_hotkey}") - if slice_window == window: - download_tasks.append(handle_file(s3_client, bucket, filename, slice_hotkey, slice_window)) - except Exception: - logger.exception(f"Error processing filename {filename}") - continue - # Download the files concurrently - results = await asyncio.gather(*download_tasks) - files.extend([res for res in results if res]) - logger.trace(f"Completed processing page for prefix {prefix}") - logger.trace(f"Completed processing bucket {bucket} for windows {windows}") - return files - -async def download_slices_for_buckets_and_windows(buckets: List[str], windows: List[int], key:str = 'slice') -> Dict[int, List[SimpleNamespace]]: - """ - Downloads files from multiple S3 buckets for the given windows. - - Args: - buckets (List[str]): A list of S3 bucket names. - windows (List[int]): A list of window identifiers. - - Returns: - Dict[int, List[SimpleNamespace]]: A dictionary mapping windows to lists of file metadata and paths. - """ - logger.debug(f"Downloading files for buckets {set(buckets)} and windows {windows}") - session = get_session() - async with session.create_client( - 's3', - region_name='us-east-1', - config=client_config, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY - ) as s3_client: - tasks = [] - for bucket in set(buckets): - if not bucket: - continue - tasks.append(process_bucket(s3_client, bucket, windows, key)) - results = await asyncio.gather(*tasks) - # Flatten the list of lists - files = [item for sublist in results for item in sublist] - - # Create a dictionary with windows as keys and list of files as values - windows_dict = {} - for file in files: - window = file.window - if window not in windows_dict: - windows_dict[window] = [] - windows_dict[window].append(file) - - logger.debug(f"Downloaded all files grouped by windows: {windows}") - return windows_dict - -async def load_files_for_window(window: int, key: str = 'slice') -> List[str]: - """ - Retrieves the paths to downloaded window files from the temporary directory. - - Args: - window (int): The window identifier. - - Returns: - List[str]: A list of file paths corresponding to the window. 
- """ - logger.debug(f"Retrieving files for window {window} from temporary directory") - temp_dir = tempfile.gettempdir() - window_files = [] - for filename in os.listdir(temp_dir): - if filename.startswith(f"{key}-{window}-") and filename.endswith(".pt"): - window_files.append(os.path.join(temp_dir, filename)) - logger.debug(f"Found file {filename} for window {window}") - return window_files - -async def delete_files_before_window(window_max: int, key:str = 'slice'): - """ - Deletes all files on the local machine which have a window id before a specific value window_max. - - Args: - window_max (int): The maximum window id. Files with window ids less than this value will be deleted. - """ - logger.debug(f"Deleting files with window id before {window_max}") - temp_dir = tempfile.gettempdir() - for filename in os.listdir(temp_dir): - if filename.startswith(f"{key}-") and ( filename.endswith(".pt") or filename.endswith(".lock") ): - try: - parts = filename.split('-') - window_id = int(parts[1]) - if window_id < window_max: - if os.path.exists(filename): - os.remove(filename) - logger.debug(f"Deleted file {filename}") - except Exception as e: - logger.error(f"Error deleting file {filename}: {e}") - -async def delete_files_from_bucket_before_window(bucket: str, window_max: int, key: str = 'slice'): - """ - Deletes all files in the specified S3 bucket which have a window id before a specific value window_max. - - Args: - bucket (str): The name of the S3 bucket. - window_max (int): The maximum window id. Files with window ids less than this value will be deleted. - """ - logger.debug(f"Deleting files in bucket {bucket} with window id before {window_max}") - session = get_session() - async with session.create_client( - 's3', - region_name='us-east-1', - config=client_config, - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY - ) as s3_client: - try: - response = await s3_client.list_objects_v2(Bucket=bucket) - if 'Contents' in response: - for obj in response['Contents']: - filename = obj['Key'] - if filename.startswith(f"{key}-") and filename.endswith(".pt"): - try: - parts = filename.split('-') - window_id = int(parts[1]) - if window_id < window_max: - await s3_client.delete_object(Bucket=bucket, Key=filename) - logger.debug(f"Deleted file {filename} from bucket {bucket}") - except Exception as e: - logger.error(f"Error deleting file {filename} from bucket {bucket}: {e}") - except Exception as e: - logger.error(f"Error listing objects in bucket {bucket}: {e}") diff --git a/dataset.py b/dataset.py deleted file mode 100644 index 26cb18a..0000000 --- a/dataset.py +++ /dev/null @@ -1,483 +0,0 @@ -# The MIT License (MIT) -# © 2024 Chakana.tech - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import time
-import typing
-import random
-import requests
-import asyncio
-import aiohttp
-import numpy as np
-from tqdm import tqdm
-from transformers import AutoTokenizer
-from torch.utils.data import IterableDataset
-
-class SubsetLoader(IterableDataset):
-    """
-    Base class for data-specific subset loader classes.
-
-    # TODO: Make this class abstract
-    """
-    def __init__(
-        self,
-        batch_size=None,
-        sequence_length=None,
-        num_pages=None,
-        tokenizer: AutoTokenizer = None,
-        pack_samples: bool = False,
-    ):
-        self.batch_size = batch_size
-        self.sequence_length = sequence_length
-        self.num_pages = num_pages
-        self.tokenizer = tokenizer
-        self.pack_samples = pack_samples
-
-        self.num_rows_per_page = 100
-
-        # Buffer to hold pages loaded from the API
-        self.buffer = []
-
-        # Buffer to hold pages already loaded into a batch
-        self.used_buffer = []
-
-        # Buffer to hold padded pages
-        self.padded_buffer = []
-
-        self.lock = asyncio.Lock()  # For thread-safe operations
-
-    async def fetch_data_for_pages(self, pages):
-        """
-        Set the pages to be used to fill the buffer. Then fetch the page data
-        to the buffer.
-        """
-
-        self.pages = pages
-
-        # Empty the buffer if it is not already empty.
-        self.buffer = []
-
-        async with aiohttp.ClientSession() as session:
-            tasks = [self._fetch_data_for_page(page, session) for page in self.pages]
-            await asyncio.gather(*tasks)
-
-    async def _fetch_data_for_page(self, page, session):
-        retry_limit = 10
-        attempt = 0
-        while attempt < retry_limit:
-            config_name, page_number, split = page
-
-            # Create the request parameters
-            params = dict(dataset=self.name,
-                          config=config_name,
-                          split=split,
-                          offset=page_number,
-                          limit=self.num_rows_per_page
-                          )
-
-            try:
-                async with session.get(self.rows_base_url, params=params) as response:
-                    response.raise_for_status()
-                    data = await response.json()
-
-                    # Prepare the data to append
-                    buffer_to_append = []
-                    for row in data["rows"]:
-                        content = row["row"]["text"]
-                        input_ids = self.tokenizer(content, truncation=True)["input_ids"]
-                        buffer_to_append.extend(input_ids)
-                        buffer_to_append.append(self.tokenizer.eos_token_id)
-
-                    async with self.lock:
-                        self.buffer.extend(buffer_to_append)
-                        self.pages.append((config_name, page_number, split))
-                    break  # Success, exit retry loop
-
-            except aiohttp.ClientResponseError:
-                attempt += 1
-                if attempt < retry_limit:
-                    await asyncio.sleep(5)
-                else:
-                    raise
-
-    def _get_pad_size(self, input_ids):
-        """
-        Get the number of tokens to be padded to the sample to match
-        the max allowed sequence length.
-        If sample packing is activated, then return 1.
-        """
-
-        if self.pack_samples:
-            return 1
-
-        sample_size = len(input_ids)
-
-        remainder = (sample_size % self.sequence_length)
-        pad_size = (self.sequence_length - remainder)
-
-        # Apply modulo again to guarantee a pad size of 0 if remainder is 0
-        pad_size = pad_size % self.sequence_length
-
-        return pad_size
-
-    def _refill_padded_buffer(self):
-        """
-        This method pulls one page from `self.buffer`, pads it, and pushes
-        it to `self.padded_buffer`.
-        """
-
-        while (
-            self.buffer
-            and len(self.padded_buffer) < self.sequence_length
-        ):
-
-            input_ids = []
-
-            # search for EOS token index and cut the buffer at it.
- EOS_index = self.buffer.index(self.tokenizer.eos_token_id) - input_ids = self.buffer[:EOS_index+1] - self.buffer =self.buffer[EOS_index+1:] - - self.used_buffer += input_ids - - # Add to padded buffer without the EOS token. - self.padded_buffer += input_ids[:-1] - - # Pad - self.padded_buffer += [self.tokenizer.eos_token_id] * self._get_pad_size(input_ids=input_ids[:-1]) - - def __iter__(self): - self.buffer = self.used_buffer + self.buffer - self.padded_buffer = [] - - # Pad and prepare one page for batching - self._refill_padded_buffer() - - return self - - def __next__(self): - batch = [] - - while len(self.padded_buffer) >= self.sequence_length: - batch.append(self.padded_buffer[: self.sequence_length]) - self.padded_buffer = self.padded_buffer[self.sequence_length :] - self._refill_padded_buffer() - - if len(batch) == self.batch_size: - return np.stack(batch) - - raise StopIteration - - -class DatasetLoader(SubsetLoader): - - name: str = "HuggingFaceFW/fineweb-edu-score-2" - rows_base_url: str = "https://datasets-server.huggingface.co/rows" - size_base_url: str = "https://datasets-server.huggingface.co/size" - - retry_limit: int = 10 # Number of retries - retry_delay: int = 5 # Seconds to wait between retries - num_rows_per_page: int = 100 - - @staticmethod - async def next_pages(offset: int, n_pages: int, seed: str, num_rows_per_page: int = 100): - configs_data = await DatasetLoader.fetch_dataset_configs() - rng = np.random.default_rng(hash(seed) & 0xffffffff) # Create a generator with a seed - rng.bit_generator.advance(offset) # Efficiently skip ahead `n` steps - result = [] - for _ in range(n_pages): - config = rng.choice(list(configs_data.keys())) - choice = rng.integers(0, configs_data[config]['num_rows'] - 1 - num_rows_per_page) - result.append((str(config), int(choice), configs_data[config]['split'])) - return result - - def __init__( - self, - batch_size=None, - sequence_length=None, - num_pages=None, - pages_info=None, - tokenizer: AutoTokenizer = None, - pack_samples: bool = False, - ): - super().__init__(batch_size, - sequence_length, - num_pages, - tokenizer, - pack_samples) - - # Initialize properties - self.configs_data = None - self.pages = [] - self.buffer = [] - self.lock = asyncio.Lock() # For thread-safe operations - - @classmethod - async def create( - cls, - batch_size=None, - sequence_length=None, - num_pages=None, - pages_info=None, - tokenizer: AutoTokenizer = None, - pack_samples: bool = False, - ): - self = cls( - batch_size=batch_size, - sequence_length=sequence_length, - num_pages=num_pages, - tokenizer=tokenizer, - pack_samples=pack_samples - ) - - # Fetch dataset configs asynchronously - self.configs_data = await cls.fetch_dataset_configs() - - if pages_info is not None: - await self._fetch(pages_info) - elif self.num_pages: - await self._fetch_data_to_buffer(self.num_pages) - - return self - - async def _fetch(self, page_info: typing.Tuple[str, int, str]): - self.pages = list(page_info) - num_pages = len(self.pages) - async with aiohttp.ClientSession() as session: - tasks = [self._fetch_data_for_page((config_name, page, split), session) - for (config_name, page, split) in self.pages] - await asyncio.gather(*tasks) - - async def _fetch_data_to_buffer(self, num_pages): - """ - Randomly sample pages and add their data to the buffer. - If a page is inaccessible, another one is sampled. - This method sets the `pages` property. 
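The padding rule in `_get_pad_size` rounds each sample up to the next multiple of `sequence_length`, which is what lets `__next__` cut exact `sequence_length` rows. The same arithmetic as a standalone sketch with worked values:

```python
def pad_size(sample_size: int, sequence_length: int, pack_samples: bool = False) -> int:
    # Same arithmetic as SubsetLoader._get_pad_size; the final modulo makes
    # the pad zero when the sample already ends on a boundary.
    if pack_samples:
        return 1
    remainder = sample_size % sequence_length
    return (sequence_length - remainder) % sequence_length

assert pad_size(1500, 2048) == 548   # 1500 + 548 = 2048
assert pad_size(2048, 2048) == 0     # already aligned
assert pad_size(2049, 2048) == 2047  # spills into a second block
```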
- """ - self.pages = [] - pages_to_fetch = self.get_random_pages(num_pages) - - async with aiohttp.ClientSession() as session: - tasks = [self._fetch_data_for_page(page, session) for page in pages_to_fetch] - await asyncio.gather(*tasks) - - async def fetch_data_to_rows(self, num_pages): - rows = [] - pages_to_fetch = self.get_random_pages(num_pages) - - async with aiohttp.ClientSession() as session: - tasks = [self._fetch_rows_for_page(page, session) for page in pages_to_fetch] - results = await asyncio.gather(*tasks) - for page_rows in results: - rows.extend(page_rows) - - return rows - - - async def _fetch_data_for_page(self, page, session): - """ - Fetches data asynchronously for a single page, processes it without blocking the event loop, - and appends the tokenized data to the buffer. - - Args: - page: A tuple containing the config name, page number, and split. - session: The HTTP session used for making requests. - - Raises: - Exception: If the maximum number of retry attempts is exceeded. - """ - retry_limit = self.retry_limit - attempt = 0 - while attempt < retry_limit: - config_name, page_number, split = page - - # Create the request parameters - params = { - 'dataset': self.name, - 'config': config_name, - 'split': split, - 'offset': page_number, - 'limit': self.num_rows_per_page - } - - try: - # Make an asynchronous HTTP GET request to fetch the data - async with session.get(self.rows_base_url, params=params) as response: - response.raise_for_status() # Raise an exception for HTTP errors - data = await response.json() - - # Prepare the data to append - buffer_to_append = [] - - # Asynchronously process each row without blocking the event loop - tasks = [ - self._tokenize_content(row["row"]["text"]) for row in data["rows"] - ] - - # Gather the tokenized results concurrently - row_input_ids = await asyncio.gather(*tasks) - - # Flatten the list of input IDs and append them to the buffer - for input_ids in row_input_ids: - buffer_to_append.extend(input_ids) - - # Safely append the processed data to the shared buffer - async with self.lock: - self.buffer.extend(buffer_to_append) - self.pages.append((config_name, page_number, split)) - break # Success, exit retry loop - - except aiohttp.ClientResponseError as e: - # Handle HTTP client errors with a retry mechanism - attempt += 1 - if attempt < retry_limit: - await asyncio.sleep(self.retry_delay) # Wait before retrying - else: - raise Exception(f"Maximum retry attempts exceeded for page {page}") from e - - async def _tokenize_content(self, content): - """ - Asynchronously tokenizes a string of content using the tokenizer in a separate thread. - - Args: - content: The text content to be tokenized. - - Returns: - The list of token IDs for the content, including the EOS token. 
- """ - # Offload the CPU-bound tokenization to a thread executor to prevent blocking the event loop - input_ids = await asyncio.to_thread( - self.tokenizer.encode, content, truncation=True, max_length=self.sequence_length - ) - input_ids.append(self.tokenizer.eos_token_id) - return input_ids - - async def _fetch_rows_for_page(self, page, session): - retry_limit = self.retry_limit - attempt = 0 - while attempt < retry_limit: - config_name, page_number, split = page - - # Create the request parameters - params = dict(dataset=self.name, - config=config_name, - split=split, - offset=page_number, - limit=self.num_rows_per_page - ) - - try: - async with session.get(self.rows_base_url, params=params) as response: - response.raise_for_status() - data = await response.json() - - # Collect the rows - return [row["row"]["text"] for row in data["rows"]] - - except aiohttp.ClientResponseError as e: - attempt += 1 - if attempt < retry_limit: - await asyncio.sleep(self.retry_delay) - else: - raise - - def get_random_pages(self, num_pages): - """ - Randomly sample pages. - A page is a row number of a given split of a given dataset dump. - """ - pages = [] - - for _ in range(num_pages): - # Choose a random config - config_name = random.choice(list(self.configs_data.keys())) - - # Choose a random page (row) - page = random.randint(0, - self.configs_data[config_name]['num_rows'] - 1 - self.num_rows_per_page) - - split = self.configs_data[config_name]['split'] - - pages.append((config_name, page, split)) - - return pages - - def get_page_names(self): - """ - This is a utility function that returns the page names that were used. - Each page as a single string instead of a tuple. - """ - page_names = [] - - if hasattr(self, 'pages'): - page_names = [f'{cfg_name}_{num_rows}_{split}' for - cfg_name, num_rows, split in self.pages] - - return page_names - - @staticmethod - async def fetch_dataset_configs() -> typing.Dict[str, typing.Dict]: - """ - Fetch the different dump names, aka configs, aka samples, of the - dataset. - The returned value is a dictionary with dump names as keys and - a dict of the number of rows and the split as values. 
- """ - # Request parameters - params = dict( - dataset=DatasetLoader.name - ) - - attempt = 0 - while attempt < DatasetLoader.retry_limit: - try: - async with aiohttp.ClientSession() as session: - async with session.get(DatasetLoader.size_base_url, params=params) as response: - response.raise_for_status() - - data = await response.json() - - # Extract the configs dict - configs_dict = data['size']['splits'] - - # Now create a dict with config names (except 'default') as - # keys, and the number of rows as values - configs_data = {entry['config']: {'num_rows': entry['num_rows'], - 'split': entry['split']} - for entry in configs_dict - if entry['config'] != 'default' - } - - return configs_data - - except aiohttp.ClientResponseError as e: - attempt += 1 - if attempt < DatasetLoader.retry_limit: - await asyncio.sleep(DatasetLoader.retry_delay) - else: - raise - - @staticmethod - async def next_pages_async(offset: int, n_pages: int, seed: str, num_rows_per_page: int = 100): - configs_data = await DatasetLoader.fetch_dataset_configs() - rng = np.random.default_rng(hash(seed) & 0xffffffff) # Create a generator with a seed - rng.bit_generator.advance(offset) # Efficiently skip ahead `n` steps - result = [] - for _ in range(n_pages): - config = rng.choice(list(configs_data.keys())) - choice = rng.integers(0, configs_data[config]['num_rows'] - 1 - num_rows_per_page) - result.append((str(config), int(choice), configs_data[config]['split'])) - return result diff --git a/docker_run.sh b/docker_run.sh deleted file mode 100644 index c1023e9..0000000 --- a/docker_run.sh +++ /dev/null @@ -1,387 +0,0 @@ -#!/usr/bin/env bash - -# The MIT License (MIT) -# © 2024 Chakana.tech - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. 
- -set -euo pipefail - -trap 'abort "An unexpected error occurred."' ERR - -# Set up colors and styles -if [[ -t 1 ]]; then - tty_escape() { printf "\033[%sm" "$1"; } -else - tty_escape() { :; } -fi -tty_mkbold() { tty_escape "1;$1"; } -tty_blue="$(tty_mkbold 34)" -tty_red="$(tty_mkbold 31)" -tty_green="$(tty_mkbold 32)" -tty_yellow="$(tty_mkbold 33)" -tty_bold="$(tty_mkbold 39)" -tty_reset="$(tty_escape 0)" - -ohai() { - printf "${tty_blue}==>${tty_bold} %s${tty_reset}\n" "$*" -} - -pdone() { - printf "${tty_green}[✔]${tty_bold} %s${tty_reset}\n" "$*" -} - -info() { - printf "${tty_green}%s${tty_reset}\n" "$*" -} - -warn() { - printf "${tty_yellow}Warning${tty_reset}: %s\n" "$*" >&2 -} - -error() { - printf "${tty_red}Error${tty_reset}: %s\n" "$*" >&2 -} - -abort() { - error "$@" - exit 1 -} - -getc() { - local save_state - save_state="$(/bin/stty -g)" - /bin/stty raw -echo - IFS='' read -r -n 1 -d '' "$@" - /bin/stty "${save_state}" -} - -wait_for_user() { - local c - echo - echo "Press ${tty_bold}RETURN${tty_reset}/${tty_bold}ENTER${tty_reset} to continue or any other key to abort:" - getc c - # we test for \r and \n because some stuff does \r instead - if ! [[ "${c}" == $'\r' || "${c}" == $'\n' ]] - then - exit 1 - fi -} - -execute() { - ohai "Running: $*" - if ! "$@"; then - abort "Failed during: $*" - fi -} - -have_root_access() { - if [ "$EUID" -ne 0 ]; then - warn "This script requires root privileges to install packages." - return 1 - fi - return 0 -} - -clear -echo "" -echo "" -echo " ______ _____ _______ ______ _______ _______ __ _ __ _" -echo " |_____] | | | | ____/ | | | |_____| | \ | | \ |" -echo " |_____] |_____| |_____ | /_____ | | | | | | \_| | \_|" -echo " " -echo "" -echo "" - -wait_for_user - -# Install Git if not present -if ! command -v git &> /dev/null; then - ohai "Installing git ..." - if have_root_access; then - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - if [ -f /etc/os-release ]; then - . /etc/os-release - if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then - ohai "Detected Ubuntu, installing Git..." - execute apt-get update -y - execute apt-get install -y git - else - warn "Unsupported Linux distribution: $ID" - abort "Cannot install Git automatically" - fi - else - warn "Cannot detect Linux distribution" - abort "Cannot install Git automatically" - fi - elif [[ "$OSTYPE" == "darwin"* ]]; then - ohai "Detected macOS, installing Git..." - execute xcode-select --install - else - abort "Unsupported OS type: $OSTYPE" - fi - else - abort "Root access is required to install Git." - fi -else - pdone "Found Git" -fi - -# Check if we are inside the cont repository -if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then - REPO_PATH="." -else - if [ ! -d "cont" ]; then - ohai "Cloning boltzmann ..." - execute git clone https://github.com/unconst/cont - REPO_PATH="cont/" - else - REPO_PATH="cont/" - fi -fi -pdone "Pulled Boltzmann $REPO_PATH" - -# Install Node.js and npm if not present -if ! command -v npm &> /dev/null; then - ohai "Installing Node.js and npm ..." - if have_root_access; then - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - execute apt-get update -y - execute apt-get install -y curl - curl -fsSL https://deb.nodesource.com/setup_20.x | bash - - execute apt-get install -y nodejs - elif [[ "$OSTYPE" == "darwin"* ]]; then - ohai "Detected macOS, installing Node.js and npm..." 
- execute brew install node - else - abort "Unsupported OS type: $OSTYPE" - fi - else - abort "Root access is required to install Node.js and npm." - fi - pdone "Installed Node.js and npm" -else - pdone "Found npm" -fi - -# Install pm2 -if ! command -v pm2 &> /dev/null; then - ohai "Installing pm2 ..." - execute npm install pm2 -g - pdone "Installed pm2" -else - pdone "Found pm2" -fi - -# Install Python 3.12 if not installed -if ! command -v python3.12 &> /dev/null; then - ohai "Installing python3.12 ..." - if have_root_access; then - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - if [ -f /etc/os-release ]; then - . /etc/os-release - if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then - ohai "Detected Ubuntu, installing Python 3.12..." - execute apt-get update -y - execute apt-get install -y software-properties-common gnupg - - # Add the deadsnakes PPA manually - ohai "Adding deadsnakes PPA manually..." - execute mkdir -p /etc/apt/keyrings - execute curl -fsSL "https://keyserver.ubuntu.com/pks/lookup?op=get&search=0x6A755776" | gpg --dearmor --batch --yes -o /etc/apt/keyrings/deadsnakes-archive-keyring.gpg - echo "deb [signed-by=/etc/apt/keyrings/deadsnakes-archive-keyring.gpg] http://ppa.launchpad.net/deadsnakes/ppa/ubuntu jammy main" > /etc/apt/sources.list.d/deadsnakes-ppa.list - - execute apt-get update -y - execute apt-get install -y python3.12 python3.12-venv - - else - warn "Unsupported Linux distribution: $ID" - abort "Cannot install Python 3.12 automatically" - fi - else - warn "Cannot detect Linux distribution" - abort "Cannot install Python 3.12 automatically" - fi - elif [[ "$OSTYPE" == "darwin"* ]]; then - ohai "Detected macOS, installing Python 3.12..." - execute brew install python@3.12 - else - abort "Unsupported OS type: $OSTYPE" - fi - else - abort "Root access is required to install Python 3.12." - fi - pdone "Installed python3.12" -else - pdone "Found python3.12" -fi - -touch ~/.bash_profile - -# Prompt the user for AWS credentials and inject them into the bash_profile file if not already stored -if ! grep -q "AWS_ACCESS_KEY_ID" ~/.bash_profile || ! grep -q "AWS_SECRET_ACCESS_KEY" ~/.bash_profile || ! grep -q "BUCKET" ~/.bash_profile; then - clear - warn "This script will store your AWS credentials in your ~/.bash_profile file." - warn "This is not secure and is not recommended." - read -p "Do you want to proceed? [y/N]: " proceed - if [[ "$proceed" != "y" && "$proceed" != "Y" ]]; then - abort "Aborted by user." - fi - - read -p "Enter your AWS Access Key ID: " AWS_ACCESS_KEY_ID - read -p "Enter your AWS Secret Access Key: " AWS_SECRET_ACCESS_KEY - read -p "Enter your S3 Bucket Name: " BUCKET - - echo "export AWS_ACCESS_KEY_ID=\"$AWS_ACCESS_KEY_ID\"" >> ~/.bash_profile - echo "export AWS_SECRET_ACCESS_KEY=\"$AWS_SECRET_ACCESS_KEY\"" >> ~/.bash_profile - echo "export BUCKET=\"$BUCKET\"" >> ~/.bash_profile -fi - -# Source the bash_profile file to apply the changes -source ~/.bash_profile - -pdone "Found AWS credentials" - -# Create a virtual environment if it does not exist -if [ ! -d "$REPO_PATH/venv" ]; then - ohai "Creating virtual environment at $REPO_PATH..." - execute python3.12 -m venv "$REPO_PATH/venv" -fi -pdone "Created venv at $REPO_PATH" - -if [[ -z "${VIRTUAL_ENV:-}" ]]; then - ohai "Activating virtual environment..." - source "$REPO_PATH/venv/bin/activate" -fi -pdone "Activated venv at $REPO_PATH" - -ohai "Installing requirements..." 
-execute pip install --upgrade pip -execute pip install -r "$REPO_PATH/requirements.txt" -pdone "Installed requirements" - -# Check for GPUs -ohai "Checking for GPUs..." -if ! command -v nvidia-smi &> /dev/null; then - warn "nvidia-smi command not found. Please ensure NVIDIA drivers are installed." - NUM_GPUS=0 -else - NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) - ohai "Number of GPUs: $NUM_GPUS" - - if [ "$NUM_GPUS" -gt 0 ]; then - ohai "GPU Memory Information:" - nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | while read -r memory; do - ohai "$((memory / 1024)) GB" - done - else - warn "No GPUs found on this machine." - fi -fi - -# Check system RAM -ohai "Checking system RAM..." -if command -v free &> /dev/null; then - TOTAL_RAM=$(free -g | awk '/^Mem:/{print $2}') - ohai "Total System RAM: $TOTAL_RAM GB" -else - warn "Cannot determine system RAM. 'free' command not found." -fi - -# Create the default key -ohai "Creating the coldkey" -if ! python3.12 -c "import bittensor as bt; w = bt.wallet(); print(w.coldkey_file.exists_on_device())" | grep -q "True"; then - execute btcli w new_coldkey --wallet.path ~/.bittensor/wallets --wallet.name default --n-words 12 --no_password -else - ohai "Default key already exists on device." -fi - -# Ensure btcli is installed -if ! command -v btcli &> /dev/null; then - abort "btcli command not found. Please ensure it is installed." -fi - -# Create hotkeys and register them -if [ "$NUM_GPUS" -gt 0 ]; then - for i in $(seq 0 $((NUM_GPUS - 1))); do - # Check if the hotkey file exists on the device - exists_on_device=$(python3.12 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); print(w.hotkey_file.exists_on_device())" 2>/dev/null) - if [ "$exists_on_device" != "True" ]; then - echo "n" | btcli wallet new_hotkey --wallet.name default --wallet.hotkey C$i --n-words 12 > /dev/null 2>&1; - else - ohai "Hotkey C$i already exists on device." - fi - - # Check if the hotkey is registered on subnet 220 - is_registered=$(python3.12 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); sub = bt.subtensor('test'); print(sub.is_hotkey_registered_on_subnet(hotkey_ss58=w.hotkey.ss58_address, netuid=220))" 2>/dev/null) - if [[ "$is_registered" != *"True"* ]]; then - ohai "Registering key on subnet 220" - btcli subnet pow_register --wallet.name default --wallet.hotkey C$i --netuid 220 --subtensor.network test --no_prompt > /dev/null 2>&1; - else - ohai "Key is already registered on subnet 220" - fi - done -else - warn "No GPUs found. Skipping hotkey creation." -fi - -ohai "Logging into wandb..." -execute wandb login - -# Delete items from bucket -PROJECT=${2:-aesop} -ohai "Cleaning bucket $BUCKET..." 
-execute python3.12 "$REPO_PATH/tools/clean.py" --bucket "$BUCKET" - -# Start all the processes again -if [ "$NUM_GPUS" -gt 0 ]; then - for i in $(seq 0 $((NUM_GPUS - 1))); do - # Adjust GPU index for zero-based numbering - GPU_INDEX=$i - GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | sed -n "$((i + 1))p") - if [ -z "$GPU_MEMORY" ]; then - warn "Could not get GPU memory for GPU $i" - continue - fi - # Determine batch size based on GPU memory - # This section adjusts the batch size for the miner based on the available GPU memory - # Higher memory allows for larger batch sizes, which can improve performance - if [ "$GPU_MEMORY" -ge 80000 ]; then - # For GPUs with 80GB or more memory, use a batch size of 6 - BATCH_SIZE=6 - elif [ "$GPU_MEMORY" -ge 40000 ]; then - # For GPUs with 40GB to 79GB memory, use a batch size of 3 - BATCH_SIZE=3 - elif [ "$GPU_MEMORY" -ge 20000 ]; then - # For GPUs with 20GB to 39GB memory, use a batch size of 1 - BATCH_SIZE=1 - else - # For GPUs with less than 20GB memory, also use a batch size of 1 - # This ensures that even lower-end GPUs can still participate - BATCH_SIZE=1 - fi - ohai "Starting miner on GPU $GPU_INDEX with batch size $BATCH_SIZE..." - execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" - done -else - warn "No GPUs found. Skipping miner startup." -fi - -pm2 list - -ohai "Script completed successfully." diff --git a/hparams.py b/hparams.py deleted file mode 100644 index 2dde865..0000000 --- a/hparams.py +++ /dev/null @@ -1,87 +0,0 @@ -# The MIT License (MIT) -# © 2024 Chakana.tech - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -import os -import json -import time -import requests -from types import SimpleNamespace -from transformers import AutoTokenizer, LlamaConfig - -from common import * - -# Cache file path -HPARAMS_FILE = "hparams.json" - -def create_namespace(hparams: dict) -> SimpleNamespace: - """ - Create a SimpleNamespace from the hyperparameters and add model configuration. - - Args: - hparams (dict): Hyperparameters dictionary. - - Returns: - SimpleNamespace: Namespace containing hyperparameters and model configuration. 
- """ - hparams_ns = SimpleNamespace(**hparams) - - hparams_ns.tokenizer = AutoTokenizer.from_pretrained( - hparams_ns.tokenizer_name, verbose=False, clean_up_tokenization_spaces=True - ) - hparams_ns.tokenizer.pad_token = hparams_ns.tokenizer.eos_token - - hparams_ns.model_config = LlamaConfig( - vocab_size=hparams_ns.tokenizer.vocab_size, - hidden_size=hparams_ns.hidden_size, - num_hidden_layers=hparams_ns.num_hidden_layers, - num_attention_heads=hparams_ns.num_attention_heads, - intermediate_size=hparams_ns.intermediate_size, - num_key_value_heads=hparams_ns.num_key_value_heads, - activation_function=hparams_ns.activation_function, - max_position_embeddings=hparams_ns.max_position_embeddings, - ) - - return hparams_ns - -def load_hparams() -> SimpleNamespace: - """ - Load hyperparameters from a GitHub file, with caching and fallback mechanisms. - - Returns: - SimpleNamespace: A namespace containing the hyperparameters and model configuration. - - Example: - hparams = load_hparams() - print(hparams.hidden_size) - print(hparams.model_config) - """ - github_url = f"https://raw.githubusercontent.com/unconst/cont/master/hparams.json?timestamp={int(time.time())}" - try: - # Attempt to fetch from the GitHub file first - response = requests.get(github_url, timeout=10, headers={'Cache-Control': 'no-cache'}) - response.raise_for_status() - hparams = json.loads(response.text) - logger.debug("Successfully loaded parameters from GitHub.") - except (requests.RequestException, json.JSONDecodeError) as e: - logger.debug(f"Error loading parameters from GitHub: {e}") - logger.debug("Attempting to load from cache...") - with open(HPARAMS_FILE, "r") as f: - hparams = json.load(f) - # Cache the new parameters - with open(HPARAMS_FILE, "w") as f: - json.dump(hparams, f, indent=4) - return create_namespace(hparams) diff --git a/miner.py b/miner.py index 1b078f6..966890a 100644 --- a/miner.py +++ b/miner.py @@ -41,8 +41,8 @@ # Import local files. from boltz.common import * -from hparams import load_hparams -from dataset import DatasetLoader +from boltz.hparams import load_hparams +from boltz.dataset import DatasetLoader from boltz.fsdp import fsdp_auto_wrap_policy # GPU optimizations. diff --git a/run.sh b/run.sh deleted file mode 100755 index 9db4a4d..0000000 --- a/run.sh +++ /dev/null @@ -1,555 +0,0 @@ -#!/usr/bin/env bash - -# The MIT License (MIT) -# © 2024 Chakana.tech - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. 
- -set -euo pipefail - -# Initialize default values -DEBUG=false -PROJECT="aesop" -AWS_ACCESS_KEY_ID="" -AWS_SECRET_ACCESS_KEY="" -BUCKET="" - -# Function to display help message -display_help() { - cat << EOF -Usage: $0 [options] - -Options: - --debug Enable debug mode - --project Set the project name (default: aesop) - --aws-access-key-id Set AWS Access Key ID - --aws-secret-access-key Set AWS Secret Access Key - --bucket Set the S3 bucket name - -h, --help Display this help message - -Description: - Installs and runs a Boltzmann miner on your GPU. -EOF -} - -# Parse command-line arguments -while [[ $# -gt 0 ]]; do - key="$1" - case $key in - --debug) - DEBUG=true - shift - ;; - --project) - PROJECT="$2" - shift 2 - ;; - --aws-access-key-id) - AWS_ACCESS_KEY_ID="$2" - shift 2 - ;; - --aws-secret-access-key) - AWS_SECRET_ACCESS_KEY="$2" - shift 2 - ;; - --bucket) - BUCKET="$2" - shift 2 - ;; - -h|--help|-help|--h) - display_help - exit 0 - ;; - *) - echo "Unknown option: $1" - display_help - exit 1 - ;; - esac -done - -# Set up colors and styles -if [[ -t 1 ]]; then - tty_escape() { printf "\033[%sm" "$1"; } -else - tty_escape() { :; } -fi -tty_mkbold() { tty_escape "1;$1"; } -tty_blue="$(tty_mkbold 34)" -tty_red="$(tty_mkbold 31)" -tty_green="$(tty_mkbold 32)" -tty_yellow="$(tty_mkbold 33)" -tty_bold="$(tty_mkbold 39)" -tty_reset="$(tty_escape 0)" - -# Logging functions -ohai() { - printf "${tty_blue}==>${tty_bold} %s${tty_reset}\n" "$*" -} - -pdone() { - printf " ${tty_green}[✔]${tty_bold} %s${tty_reset}\n" "$*" -} - -info() { - printf "${tty_green}%s${tty_reset}\n" "$*" -} - -warn() { - printf "${tty_yellow}Warning${tty_reset}: %s\n" "$*" >&2 -} - -error() { - printf "${tty_red}Error${tty_reset}: %s\n" "$*" >&2 -} - -abort() { - error "$@" - exit 1 -} - -trap 'abort "An unexpected error occurred."' ERR - -getc() { - local save_state - save_state="$(/bin/stty -g)" - /bin/stty raw -echo - IFS='' read -r -n 1 -d '' "$@" - /bin/stty "${save_state}" -} - -wait_for_user() { - local c - echo - echo "Press ${tty_bold}RETURN${tty_reset}/${tty_bold}ENTER${tty_reset} to continue or any other key to abort:" - getc c - # we test for \r and \n because some stuff does \r instead - if ! [[ "${c}" == $'\r' || "${c}" == $'\n' ]] - then - exit 1 - fi -} - -execute() { - ohai "Running: $*" - if ! "$@"; then - abort "Failed during: $*" - fi -} - -have_sudo_access() { - if ! command -v sudo &> /dev/null; then - warn "sudo command not found. Please install sudo or run as root." - return 1 - fi - if [ "$EUID" -ne 0 ]; then - if ! sudo -n true 2>/dev/null; then - warn "This script requires sudo access to install packages. Please run as root or ensure your user has sudo privileges." - return 1 - fi - fi - return 0 -} - -execute_sudo() { - if have_sudo_access; then - ohai "sudo $*" - if ! sudo "$@"; then - abort "Failed to execute: sudo $*" - fi - else - warn "Sudo access is required, attempting to run without sudo" - ohai "$*" - if ! 
"$@"; then - abort "Failed to execute: $*" - fi - fi -} - -# Function to set or replace environment variables in bash_profile -set_or_replace_env_var() { - local var_name="$1" - local var_value="$2" - local profile_file="$3" - - # Escape special characters for sed - local escaped_var_value=$(printf '%s\n' "$var_value" | sed -e 's/[\/&]/\\&/g') - - if grep -q "^export $var_name=" "$profile_file"; then - # Variable exists, replace it - sed -i.bak "s/^export $var_name=.*/export $var_name=\"$escaped_var_value\"/" "$profile_file" - else - # Variable does not exist, append it - echo "export $var_name=\"$var_value\"" >> "$profile_file" - fi -} - -# Clear the screen and display the logo -clear -echo "" -echo "" -echo " ______ _____ _______ ______ _______ _______ __ _ __ _" -echo " |_____] | | | | ____/ | | | |_____| | \ | | \ |" -echo " |_____] |_____| |_____ | /_____ | | | | | | \_| | \_|" -echo " " -echo "" -echo "" - -echo "This script will do the following:" -echo "1. Install required software (Git, npm, pm2, Python 3.12)" -echo "2. Set up AWS credentials" -echo "3. Clone and set up the Boltzmann repository" -echo "4. Create and register Bittensor wallets" -echo "5. Configure wandb for logging" -echo "6. Clean the specified S3 bucket" -echo "7. Start Boltzmann miners on available GPUs" -echo "" -echo "Please ensure you have a stable internet connection and sufficient permissions to install software." -echo "" - -wait_for_user - -# Ensure ~/.bash_profile exists -touch ~/.bash_profile -source ~/.bash_profile - -# Backup the bash_profile -cp ~/.bash_profile ~/.bash_profile.bak - -# Prompt the user for AWS credentials if not supplied via command-line -ohai "Getting AWS credentials ..." -if [[ -z "$AWS_ACCESS_KEY_ID" ]] || [[ -z "$AWS_SECRET_ACCESS_KEY" ]] || [[ -z "$BUCKET" ]]; then - # TODO: Consider securely storing AWS credentials rather than storing them in plain text - warn "This script will store your AWS credentials in your ~/.bash_profile file." - warn "This is not secure and is not recommended." - read -p "Do you want to proceed? [y/N]: " proceed - if [[ "$proceed" != "y" && "$proceed" != "Y" ]]; then - abort "Aborted by user." - fi - - if [[ -z "$AWS_ACCESS_KEY_ID" ]]; then - read -p "Enter your AWS Access Key ID: " AWS_ACCESS_KEY_ID - fi - if [[ -z "$AWS_SECRET_ACCESS_KEY" ]]; then - read -p "Enter your AWS Secret Access Key: " AWS_SECRET_ACCESS_KEY - fi - if [[ -z "$BUCKET" ]]; then - read -p "Enter your S3 Bucket Name: " BUCKET - fi -fi - -# Overwrite or add the AWS credentials in the bash_profile -set_or_replace_env_var "AWS_ACCESS_KEY_ID" "$AWS_ACCESS_KEY_ID" ~/.bash_profile -set_or_replace_env_var "AWS_SECRET_ACCESS_KEY" "$AWS_SECRET_ACCESS_KEY" ~/.bash_profile -set_or_replace_env_var "BUCKET" "$BUCKET" ~/.bash_profile - -# Source the bash_profile to apply the changes -source ~/.bash_profile -pdone "AWS credentials set in ~/.bash_profile" - -ohai "Installing requirements ..." -# Install Git if not present -if ! command -v git &> /dev/null; then - ohai "Git not found. Installing git ..." - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - if [ -f /etc/os-release ]; then - . /etc/os-release - if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then - ohai "Detected Ubuntu, installing Git..." 
- if [[ "$DEBUG" == "true" ]]; then - execute_sudo apt-get update -y - execute_sudo apt-get install git -y - else - execute_sudo apt-get update -y > /dev/null 2>&1 - execute_sudo apt-get install git -y > /dev/null 2>&1 - fi - else - warn "Unsupported Linux distribution: $ID" - abort "Cannot install Git automatically" - fi - else - warn "Cannot detect Linux distribution" - abort "Cannot install Git automatically" - fi - else - abort "Unsupported OS type: $OSTYPE" - fi -else - pdone "Git is already installed" -fi - -# TODO: Add error handling for package installations -# TODO: Ensure compatibility with different package managers - -# Check for Rust installation -if ! command -v rustc &> /dev/null; then - ohai "Installing Rust ..." - if [[ "$DEBUG" == "true" ]]; then - execute curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - else - execute curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y > /dev/null 2>&1 - fi - # Add Rust to the PATH for the current session - source $HOME/.cargo/env -fi -pdone "Rust is installed" - -# Install uv if not present -if ! command -v uv &> /dev/null; then - ohai "Installing uv ..." - if [[ "$DEBUG" == "true" ]]; then - execute curl -LsSf https://astral.sh/uv/install.sh | sh - else - execute curl -LsSf https://astral.sh/uv/install.sh | sh > /dev/null 2>&1 - fi - # Add uv to the PATH for the current session - export PATH="$HOME/.cargo/bin:$PATH" -fi -pdone "uv is installed" - -# Check if npm is installed -if ! command -v npm &> /dev/null; then - ohai "Installing npm ..." - if ! command -v node &> /dev/null; then - ohai "Node.js could not be found, installing..." - if ! curl -fsSL https://deb.nodesource.com/setup_18.x | bash; then - abort "Failed to download Node.js setup script" - fi - if ! execute_sudo apt-get install -y nodejs; then - abort "Failed to install Node.js" - fi - fi - if ! curl -L https://www.npmjs.com/install.sh | sh; then - abort "Failed to install npm" - fi -fi -pdone "npm is installed" - -# Install pm2 -if ! command -v pm2 &> /dev/null; then - ohai "Installing pm2 ..." - if [[ "$DEBUG" == "true" ]]; then - execute npm install pm2 -g - else - execute npm install pm2 -g > /dev/null 2>&1 - fi -fi -pdone "pm2 is installed" - -ohai "Installing Boltzmann ..." -# Check if we are inside the boltzmann repository -if git rev-parse --is-inside-work-tree > /dev/null 2>&1; then - REPO_PATH="." -else - if [ ! -d "boltzmann" ]; then - ohai "Cloning boltzmann ..." - execute git clone https://github.com/unconst/boltzmann - REPO_PATH="boltzmann/" - else - REPO_PATH="boltzmann/" - fi -fi -pdone "Boltzmann repository is ready at $REPO_PATH" - -# Install Python 3.12 if not installed -if ! command -v python3.12 &> /dev/null; then - ohai "Installing python3.12 ..." - if [[ "$OSTYPE" == "linux-gnu"* ]]; then - ohai "Detected Linux" - if [ -f /etc/os-release ]; then - . /etc/os-release - if [[ "$ID" == "ubuntu" || "$ID_LIKE" == *"ubuntu"* ]]; then - ohai "Detected Ubuntu, installing Python 3.12..." 
- if [[ "$DEBUG" == "true" ]]; then - if have_sudo_access; then - execute_sudo add-apt-repository ppa:deadsnakes/ppa -y - else - warn "Skipping add-apt-repository due to lack of sudo access" - fi - execute_sudo apt-get update -y - else - if have_sudo_access; then - execute_sudo add-apt-repository ppa:deadsnakes/ppa -y > /dev/null 2>&1 - else - warn "Skipping add-apt-repository due to lack of sudo access" - fi - execute_sudo apt-get update -y > /dev/null 2>&1 - execute_sudo apt-get install --reinstall python3-apt > /dev/null 2>&1 - execute_sudo apt-get install python3.12 -y > /dev/null 2>&1 - execute_sudo apt-get install python3.12-venv > /dev/null 2>&1 - fi - else - warn "Unsupported Linux distribution: $ID" - abort "Cannot install Python 3.12 automatically" - fi - else - warn "Cannot detect Linux distribution" - abort "Cannot install Python 3.12 automatically" - fi - else - abort "Unsupported OS type: $OSTYPE" - fi -fi -pdone "Python 3.12 is installed" - -# Create a virtual environment if it does not exist -if [ ! -d "$REPO_PATH/venv" ]; then - ohai "Creating virtual environment at $REPO_PATH..." - if [[ "$DEBUG" == "true" ]]; then - execute uv venv "$REPO_PATH/.venv" - else - execute uv venv "$REPO_PATH/.venv" > /dev/null 2>&1 - fi -fi -pdone "Virtual environment is set up at $REPO_PATH" - - -# Activate the virtual environment -ohai "Activating virtual environment ..." -source $REPO_PATH/.venv/bin/activate -pdone "Virtual environment activated" - -ohai "Installing Python requirements ..." -if [[ "$DEBUG" == "true" ]]; then - execute uv pip install -r $REPO_PATH/requirements.txt - execute uv pip install --upgrade cryptography pyOpenSSL -else - execute uv pip install -r $REPO_PATH/requirements.txt > /dev/null 2>&1 - execute uv pip install --upgrade cryptography pyOpenSSL > /dev/null 2>&1 -fi -pdone "Python requirements installed" - -# Check for GPUs -ohai "Checking for GPUs..." -if ! command -v nvidia-smi &> /dev/null; then - warn "nvidia-smi command not found. Please ensure NVIDIA drivers are installed." - NUM_GPUS=0 -else - NUM_GPUS=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) - - if [ "$NUM_GPUS" -gt 0 ]; then - nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | while read -r memory; do - pdone "Found GPU with $((memory / 1024)) GB of memory" - done - else - warn "No GPUs found on this machine." - fi -fi - -# Check system RAM -if command -v free &> /dev/null; then - TOTAL_RAM=$(free -g | awk '/^Mem:/{print $2}') - pdone "System RAM: ${TOTAL_RAM} GB" -else - warn "Cannot determine system RAM. 'free' command not found." -fi - -ohai "Creating wallets ..." -# Create the default key -if ! python3 -c "import bittensor as bt; w = bt.wallet(); print(w.coldkey_file.exists_on_device())" | grep -q "True"; then - execute btcli w new_coldkey --wallet.path ~/.bittensor/wallets --wallet.name default --n-words 12 -fi -pdone "Wallet 'default' is ready" - -# Ensure btcli is installed -if ! command -v btcli &> /dev/null; then - abort "btcli command not found. Please ensure it is installed." 
-fi - -# Create hotkeys and register them -if [ "$NUM_GPUS" -gt 0 ]; then - for i in $(seq 0 $((NUM_GPUS - 1))); do - # Check if the hotkey file exists on the device - exists_on_device=$(python3 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); print(w.hotkey_file.exists_on_device())" 2>/dev/null) - if [ "$exists_on_device" != "True" ]; then - echo "n" | btcli wallet new_hotkey --wallet.name default --wallet.hotkey C$i --n-words 12 > /dev/null 2>&1; - fi - pdone "Created Hotkey 'C$i'" - - # Check if the hotkey is registered on subnet 220 - is_registered=$(python3 -c "import bittensor as bt; w = bt.wallet(hotkey='C$i'); sub = bt.subtensor('test'); print(sub.is_hotkey_registered_on_subnet(hotkey_ss58=w.hotkey.ss58_address, netuid=220))" 2>/dev/null) - if [[ "$is_registered" != *"True"* ]]; then - ohai "Registering hotkey 'C$i' on subnet 220" - btcli subnet pow_register --wallet.name default --wallet.hotkey C$i --netuid 220 --subtensor.network test --no_prompt > /dev/null 2>&1; - fi - pdone "Registered Hotkey 'C$i' on subnet 220" - done -else - warn "No GPUs found. Skipping hotkey creation." - exit -fi -pdone "All hotkeys registered" - -ohai "Logging into wandb..." -execute wandb login -pdone "wandb is configured" - -# Clean the bucket -ohai "Cleaning bucket $BUCKET..." -if [[ "$DEBUG" == "true" ]]; then - execute python3 $REPO_PATH/tools/clean.py --bucket "$BUCKET" -else - execute python3 $REPO_PATH/tools/clean.py --bucket "$BUCKET" > /dev/null 2>&1 -fi -pdone "Bucket '$BUCKET' cleaned" - -# Close down all previous processes and restart them -if pm2 list | grep -q 'online'; then - ohai "Stopping old pm2 processes..." - pm2 delete all - pdone "Old processes stopped" -fi - -# Start all the processes again -if [ "$NUM_GPUS" -gt 0 ]; then - for i in $(seq 0 $((NUM_GPUS - 1))); do - # Adjust GPU index for zero-based numbering - GPU_INDEX=$i - GPU_MEMORY=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | sed -n "$((i + 1))p") - if [ -z "$GPU_MEMORY" ]; then - warn "Could not get GPU memory for GPU $i" - continue - fi - # Determine batch size based on GPU memory - if [ "$GPU_MEMORY" -ge 80000 ]; then - BATCH_SIZE=6 - elif [ "$GPU_MEMORY" -ge 40000 ]; then - BATCH_SIZE=3 - elif [ "$GPU_MEMORY" -ge 20000 ]; then - BATCH_SIZE=1 - else - BATCH_SIZE=1 - fi - ohai "Starting miner on GPU $GPU_INDEX with batch size $BATCH_SIZE..." - if [[ "$DEBUG" == "true" ]]; then - execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" - else - execute pm2 start "$REPO_PATH/miner.py" --interpreter python3 --name C$i -- --actual_batch_size "$BATCH_SIZE" --wallet.name default --wallet.hotkey C$i --bucket "$BUCKET" --device cuda:$GPU_INDEX --use_wandb --project "$PROJECT" > /dev/null 2>&1 - fi - done -else - warn "No GPUs found. Skipping miner startup." 
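The miner launch loop above picks a per-GPU batch size from the memory figure nvidia-smi reports in MiB. The same ladder as a Python sketch (batch_size_for_gpu is an illustrative helper, not part of the repo):

def batch_size_for_gpu(memory_mib: int) -> int:
    # Mirrors the shell ladder: ~80 GB cards get 6, ~40 GB cards get 3,
    # anything smaller falls back to 1.
    if memory_mib >= 80000:
        return 6
    if memory_mib >= 40000:
        return 3
    return 1
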
-fi -pdone "All miners started" -pm2 list - -echo "" -pdone "SUCCESS" -echo "" - -# Start logging process 1 -pm2 logs C0 - diff --git a/start.sh b/start.sh deleted file mode 100755 index cc621fc..0000000 --- a/start.sh +++ /dev/null @@ -1,36 +0,0 @@ -# The MIT License (MIT) -# © 2024 Chakana.tech - -# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated -# documentation files (the “Software”), to deal in the Software without restriction, including without limitation -# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, -# and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all copies or substantial portions of -# the Software. - -# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO -# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL -# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -# DEALINGS IN THE SOFTWARE. - -# Close down all previous processes and restart them. -pm2 sendSignal SIGINT all -pm2 delete all -# Delete items from bucket -BUCKET=${1:-decis} -PROJECT=${2:-aesop} -python3 tools/clean.py --bucket $BUCKET - -# Start all the processes again. -pm2 start validator.py --interpreter python3 --name V1 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey default --bucket $BUCKET --device cuda:0 --use_wandb --project $PROJECT -pm2 start miner.py --interpreter python3 --name M1 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M1 --bucket $BUCKET --device cuda:1 --use_wandb --project $PROJECT -pm2 start miner.py --interpreter python3 --name M2 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M2 --bucket $BUCKET --device cuda:2 --use_wandb --project $PROJECT -pm2 start miner.py --interpreter python3 --name M3 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M3 --bucket $BUCKET --device cuda:3 --use_wandb --project $PROJECT -pm2 start miner.py --interpreter python3 --name M4 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M4 --bucket $BUCKET --device cuda:5 --use_wandb --random --project $PROJECT -pm2 start miner.py --interpreter python3 --name M5 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M5 --bucket $BUCKET --device cuda:6 --use_wandb --random --project $PROJECT -pm2 start miner.py --interpreter python3 --name M6 -- --actual_batch_size 6 --wallet.name Alice --wallet.hotkey M3 --bucket $BUCKET --device cuda:4 --use_wandb --baseline --project $PROJECT - - - diff --git a/tests/eval.py b/tests/eval.py index 3d4a533..e12191d 100644 --- a/tests/eval.py +++ b/tests/eval.py @@ -28,7 +28,7 @@ import tempfile import traceback import bittensor as bt -from hparams import load_hparams +from boltz.hparams import load_hparams from types import SimpleNamespace from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM diff --git a/tests/legacy_miner.py b/tests/legacy_miner.py index d9103aa..3e33a6a 100644 --- a/tests/legacy_miner.py +++ b/tests/legacy_miner.py @@ -37,9 +37,9 @@ from torch.optim.lr_scheduler import CosineAnnealingLR # Import local files. 
-from common import *
-from hparams import load_hparams
-from dataset import DatasetLoader
+from boltz.common import *
+from boltz.hparams import load_hparams
+from boltz.dataset import DatasetLoader
 
 # GPU optimizations.
 torch.backends.cudnn.benchmark = True
diff --git a/tests/legacy_validator.py b/tests/legacy_validator.py
index 6a753f4..8aecec1 100644
--- a/tests/legacy_validator.py
+++ b/tests/legacy_validator.py
@@ -36,9 +36,9 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 
 # Import local files.
-from common import *
-from hparams import load_hparams
-from dataset import DatasetLoader
+from boltz.common import *
+from boltz.hparams import load_hparams
+from boltz.dataset import DatasetLoader
 
 # GPU optimizations.
 torch.backends.cudnn.benchmark = True
diff --git a/validator.py b/validator.py
index 29c5b12..905cc82 100644
--- a/validator.py
+++ b/validator.py
@@ -36,9 +36,9 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 
 # Import local files.
-from common import *
-from hparams import load_hparams
-from dataset import DatasetLoader
+from boltz.common import *
+from boltz.hparams import load_hparams
+from boltz.dataset import DatasetLoader
 
 # GPU optimizations.
 torch.backends.cudnn.benchmark = True

From a7120202a68496c1060ca868ae012ed54f29e38c Mon Sep 17 00:00:00 2001
From: distributedstatemachine
Date: Mon, 28 Oct 2024 04:30:10 +0400
Subject: [PATCH 6/9] feat: validator fsdp, move common fsdp utils to fsdp.py

---
 boltz/fsdp.py | 139 +++++++++++++++++++++++++++++++++++++++++++++-----
 miner.py | 71 ++++++--------------------
 validator.py | 24 +++++++--
 3 files changed, 162 insertions(+), 72 deletions(-)

diff --git a/boltz/fsdp.py b/boltz/fsdp.py
index b252ac1..abe8587 100644
--- a/boltz/fsdp.py
+++ b/boltz/fsdp.py
@@ -1,13 +1,23 @@
+from collections import defaultdict
+from typing import Callable, List, Tuple
+
+import os
+import torch
+import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed.fsdp.wrap import (
-    _or_policy,
-    lambda_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-)
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
+import functools
+from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+import torch.distributed as dist
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import ShardingStrategy
+import logging
+from boltz.common import logger
 
 
 def fsdp_auto_wrap_policy(model, transformer_layer_names):
-    import functools
+
 
     def lambda_policy_fn(module):
         if (
@@ -18,14 +28,117 @@ def lambda_policy_fn(module):
             return True
         return False
 
-    lambda_policy = functools.partial(
-        lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
-    )
+    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
     transformer_wrap_policy = functools.partial(
-        transformer_auto_wrap_policy, transformer_layer_cls=set(transformer_layer_names)
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=set(transformer_layer_names)
     )
-    auto_wrap_policy = functools.partial(
-        _or_policy, policies=[lambda_policy, transformer_wrap_policy]
-    )
+    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
 
     return auto_wrap_policy
+
+
+def wrap_model_with_fsdp(
+    model: nn.Module,
+    device_id: int,
+    transformer_layer_names: List[nn.Module],
+    is_distributed: bool,
+    device: torch.device
+) -> nn.Module:
+    """
+    Wraps a PyTorch model with Fully Sharded Data Parallel (FSDP) if distributed training is enabled.
+
+    Args:
+        model (nn.Module): The PyTorch model to wrap
+        device_id (int): The local device ID for FSDP
+        transformer_layer_names (List[nn.Module]): List of transformer layer classes to wrap
+        is_distributed (bool): Whether distributed training is enabled
+        device (torch.device): The device to move the model to
+
+    Returns:
+        nn.Module: The wrapped model, either with FSDP or moved to device
+
+    Example:
+        >>> model = LlamaForCausalLM(config)
+        >>> wrapped_model = wrap_model_with_fsdp(
+        ...     model=model,
+        ...     device_id=local_rank,
+        ...     transformer_layer_names=[LlamaDecoderLayer],
+        ...     is_distributed=dist.is_initialized(),
+        ...     device=torch.device("cuda", local_rank)
+        ... )
+    """
+    if is_distributed:
+        # Create the custom auto wrap policy
+        auto_wrap_policy = fsdp_auto_wrap_policy(
+            model=model,
+            transformer_layer_names=transformer_layer_names
+        )
+
+        # Wrap with FSDP
+        model = FSDP(
+            model,
+            device_id=device_id,
+            auto_wrap_policy=auto_wrap_policy,
+            sharding_strategy=ShardingStrategy.FULL_SHARD,
+        )
+        logger.info(f"Model wrapped with FSDP on device {device} using custom auto wrap policy.")
+    else:
+        # Move to device for single-process execution
+        model = model.to(device)
+        logger.info(f"Model moved to device {device}.")
+
+    return model
+
+def initialize_distributed_training(
+    device_config: str,
+) -> Tuple[int, int, int, torch.device]:
+    """
+    Initializes distributed training configuration and device settings.
+
+    Args:
+        device_config (str): Device configuration string (e.g. 'cuda:0')
+
+    Returns:
+        Tuple containing:
+            local_rank (int): Local process rank
+            global_rank (int): Global process rank
+            world_size (int): Total number of processes
+            device (torch.device): Torch device object
+
+    Example:
+        >>> local_rank, global_rank, world_size, device = initialize_distributed_training(
+        ...     device_config='cuda:0'
+        ... )
+    """
+    # Initialize distributed training if CUDA available and running in distributed mode
+    if torch.cuda.is_available() and "LOCAL_RANK" in os.environ:
+        # Get distributed training parameters from environment
+        local_rank = int(os.environ["LOCAL_RANK"])
+        global_rank = int(os.environ["RANK"])
+        world_size = int(os.environ["WORLD_SIZE"])
+
+        # Validate GPU availability
+        num_gpus = torch.cuda.device_count()
+        if local_rank >= num_gpus:
+            raise ValueError(f"Local rank {local_rank} exceeds number of available GPUs {num_gpus}.")
+
+        # Configure device
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)
+
+        # Initialize process group
+        dist.init_process_group(backend='nccl')
+        logger.info(f"Distributed training initialized on rank {global_rank} out of {world_size} processes.")
+
+    else:
+        # Single process execution settings
+        local_rank = 0
+        global_rank = 0
+        world_size = 1
+        device = torch.device(device_config)
+        logger.warning("Distributed training is not initialized. 
Running on a single process.") + + return local_rank, global_rank, world_size, device diff --git a/miner.py b/miner.py index 966890a..70108aa 100644 --- a/miner.py +++ b/miner.py @@ -34,16 +34,13 @@ from transformers import LlamaForCausalLM from torch.optim.lr_scheduler import CosineAnnealingLR import torch.distributed as dist -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ShardingStrategy -from functools import partial from transformers.models.llama.modeling_llama import LlamaDecoderLayer # Import local files. from boltz.common import * from boltz.hparams import load_hparams from boltz.dataset import DatasetLoader -from boltz.fsdp import fsdp_auto_wrap_policy +from boltz.fsdp import * # GPU optimizations. torch.backends.cudnn.benchmark = True @@ -111,34 +108,8 @@ def __init__(self): wandb.init(project=self.config.project, resume='allow', name=f'M{self.uid}', config=self.config) # Initialize distributed training - if torch.cuda.is_available() and "LOCAL_RANK" in os.environ: - # torchrun provides LOCAL_RANK, RANK, and WORLD_SIZE environment variables - self.local_rank = int(os.environ["LOCAL_RANK"]) - self.global_rank = int(os.environ["RANK"]) - self.world_size = int(os.environ["WORLD_SIZE"]) - - num_gpus = torch.cuda.device_count() - if self.local_rank >= num_gpus: - raise ValueError(f"Local rank {self.local_rank} exceeds number of available GPUs {num_gpus}.") - - # Set the device for this process - torch.cuda.set_device(self.local_rank) - self.device = torch.device("cuda", self.local_rank) - - # Initialize the process group - dist.init_process_group(backend='nccl') - logger.info(f"Distributed training initialized on rank {self.global_rank} out of {self.world_size} processes.") - else: - # Single process execution - self.local_rank = 0 - self.global_rank = 0 - self.world_size = 1 - self.device = torch.device(self.config.device) - logger.warning("Distributed training is not initialized. Running on a single process.") - - - # Identify if the current process is the master (rank 0). - is_master = self.global_rank == 0 + self.local_rank, self.global_rank, self.world_size, self.device = initialize_distributed_training( + device_config=self.config.device) # Init model. 
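The policy produced by fsdp_auto_wrap_policy OR-combines two predicates via _or_policy: the lambda policy picks up small leaf modules that hold their own weights, while the transformer policy makes each listed layer class (here LlamaDecoderLayer) its own FSDP unit. A minimal sketch of driving that policy directly, using the imports this patch adds to boltz/fsdp.py (shard_llama is an illustrative name, not part of the repo):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

from boltz.fsdp import fsdp_auto_wrap_policy


def shard_llama(model_config, local_rank: int) -> FSDP:
    # Each decoder layer becomes its own shard, so only one layer's full
    # parameters need to be materialized at a time during forward/backward.
    model = LlamaForCausalLM(config=model_config)
    policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
    return FSDP(
        model,
        device_id=local_rank,
        auto_wrap_policy=policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
    )
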
logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) @@ -146,30 +117,20 @@ def __init__(self): torch.manual_seed(42); np.random.seed(42); random.seed(42) self.model = LlamaForCausalLM(config=self.hparams.model_config) - # Wrap the model with FSDP if distributed training is initialized - if dist.is_initialized(): - # Define the transformer layer names to wrap - transformer_layer_names = [LlamaDecoderLayer] - - # Create the custom auto wrap policy - auto_wrap_policy = fsdp_auto_wrap_policy( - model=self.model, - transformer_layer_names=transformer_layer_names - ) + # Wrap the model with Fully Sharded Data Parallel (FSDP) if distributed training is initialized + self.model = wrap_model_with_fsdp( + model=self.model, + device_id=self.local_rank, + transformer_layer_names=[LlamaDecoderLayer], + is_distributed=dist.is_initialized(), + device=self.device + ) + init_device = "cuda" + self.model.to_empty(device=init_device) + self.model.init_weights() + self.model.train() - # Wrap the model with FSDP using the custom auto wrap policy - self.model = FSDP( - self.model, - device_id=self.local_rank, - auto_wrap_policy=auto_wrap_policy, - sharding_strategy=ShardingStrategy.FULL_SHARD, - ) - logger.info(f"Model wrapped with FSDP on device {self.device} using custom auto wrap policy.") - else: - # Move the model to the device for single-process execution - self.model.to(self.device) - logger.info(f"Model moved to device {self.device}.") - self.model.train() + self.model_parts = [self.model] self.optimizer = optim.AdamW( self.model.parameters(), lr=self.hparams.learning_rate, # Peak learning rate diff --git a/validator.py b/validator.py index 905cc82..1614946 100644 --- a/validator.py +++ b/validator.py @@ -34,11 +34,13 @@ from dotenv import dotenv_values from transformers import LlamaForCausalLM from torch.optim.lr_scheduler import CosineAnnealingLR +from transformers.models.llama.modeling_llama import LlamaDecoderLayer # Import local files. from boltz.common import * from boltz.hparams import load_hparams from boltz.dataset import DatasetLoader +from boltz.fsdp import * # GPU optimizations. torch.backends.cudnn.benchmark = True @@ -51,8 +53,8 @@ class Validator: def config(): parser = argparse.ArgumentParser(description='Validator script') parser.add_argument('--project', type=str, default='aesop2', help='Optional wandb project name') - parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') - parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') + parser.add_argument('--netuid', type=int, default=223, help='Bittensor network UID.') + parser.add_argument('--bucket', type=str, default='cont2', help='S3 bucket name') parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') @@ -100,7 +102,10 @@ def __init__(self): if run.name == f'V{self.uid}': logger.info(f'Deleting old run: {run}'); run.delete() except: pass - wandb.init(project=self.config.project, resume='allow', name=f'V{self.uid}', config=self.config) + + # Initialize distributed training + self.local_rank, self.global_rank, self.world_size, self.device = initialize_distributed_training( + device_config=self.config.device) # Init model. 
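The to_empty / init_weights pair the miner now runs is the usual meta-device initialization sequence: allocate parameter storage on the target device without copying anything, then run weight init in place. Sketched in isolation below; note it only preserves correctness if the parameters start out uninitialized (for example, built under the meta device), because to_empty discards any values they already held:

import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(5120, 5120)      # declared, but no real storage yet

layer = layer.to_empty(device="cuda")  # allocate uninitialized storage on GPU
layer.reset_parameters()               # run the module's own weight init
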
logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40) @@ -108,7 +113,18 @@ def __init__(self): torch.manual_seed(42); np.random.seed(42); random.seed(42) self.model = LlamaForCausalLM(config=self.hparams.model_config) # self.model = LlamaForCausalLM.from_pretrained('TinyLlama/TinyLlama_v1.1') - self.model.to(self.config.device) + + # Wrap the model with Fully Sharded Data Parallel (FSDP) if distributed training is initialized + self.model = wrap_model_with_fsdp( + model=self.model, + device_id=self.local_rank, + transformer_layer_names=[LlamaDecoderLayer], + is_distributed=dist.is_initialized(), + device=self.device + ) + init_device = "cuda" + self.model.to_empty(device=init_device) + self.model.init_weights() self.model.eval() # Init buckets. From 81bd53d8914aecd307cb69b099514f2694c7b220 Mon Sep 17 00:00:00 2001 From: distributedstatemachine Date: Mon, 28 Oct 2024 04:38:54 +0400 Subject: [PATCH 7/9] chore: tidy names --- miner.py | 4 ++-- scripts/start_distributed.sh | 4 ++-- validator.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/miner.py b/miner.py index 70108aa..6bac906 100644 --- a/miner.py +++ b/miner.py @@ -53,8 +53,8 @@ class Miner: def config(): parser = argparse.ArgumentParser(description='Miner script') parser.add_argument('--project', type=str, default='aesop2', help='Optional wandb project name') - parser.add_argument('--netuid', type=int, default=223, help='Bittensor network UID.') - parser.add_argument('--bucket', type=str, default='cont2', help='S3 bucket name') + parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') + parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') diff --git a/scripts/start_distributed.sh b/scripts/start_distributed.sh index 21e39b6..46764e5 100644 --- a/scripts/start_distributed.sh +++ b/scripts/start_distributed.sh @@ -20,7 +20,7 @@ pm2 sendSignal SIGINT all pm2 delete all # Delete items from bucket -BUCKET=${1:-cont2} +BUCKET=${1:-decis} PROJECT=${2:-aesop} # python3 tools/clean.py --bucket $BUCKET @@ -40,4 +40,4 @@ pm2 start "torchrun --nproc_per_node=${NGPU} \ --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ --local-ranks-filter ${LOG_RANK} --role rank \ --tee 3 miner.py -- --actual_batch_size 6 --wallet.name Bistro \ - --wallet.hotkey M111 --bucket $BUCKET --device cuda:1 --use_wandb --project $PROJECT --debug" --name M1 --interpreter none +--wallet.hotkey M111 --bucket $BUCKET --device cuda:1 --use_wandb --project $PROJECT --debug" --name M1 --interpreter none diff --git a/validator.py b/validator.py index 1614946..e743fd0 100644 --- a/validator.py +++ b/validator.py @@ -53,8 +53,8 @@ class Validator: def config(): parser = argparse.ArgumentParser(description='Validator script') parser.add_argument('--project', type=str, default='aesop2', help='Optional wandb project name') - parser.add_argument('--netuid', type=int, default=223, help='Bittensor network UID.') - parser.add_argument('--bucket', type=str, default='cont2', help='S3 bucket name') + parser.add_argument('--netuid', type=int, default=220, help='Bittensor network UID.') + parser.add_argument('--bucket', type=str, default='decis', help='S3 bucket name') 
parser.add_argument('--actual_batch_size', type=int, default=8, help='Training batch size per accumulation.') parser.add_argument('--device', type=str, default='cuda', help='Device to use for training (e.g., cpu or cuda)') parser.add_argument('--use_wandb', action='store_true', help='Use Weights and Biases for logging') From a71291468ef9aeb905e364dc711ca2d00bc71c3c Mon Sep 17 00:00:00 2001 From: distributedstatemachine Date: Tue, 29 Oct 2024 15:12:23 +0400 Subject: [PATCH 8/9] fix: patch apply indices to broadcast properly --- boltz/common.py | 69 ++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/boltz/common.py b/boltz/common.py index 482e60e..11d7c74 100644 --- a/boltz/common.py +++ b/boltz/common.py @@ -114,40 +114,12 @@ async def get_slices(filename: str, device: str) -> Dict[str, torch.Tensor]: weights_only=True, ) - async def apply_slices_to_model( model: nn.Module, window: int, seed: str, compression: int, key: str = "slice" ) -> List[str]: - """ - Applies slices from a specific window to the given FSDP model. - - Args: - model (torch.nn.Module): The FSDP-wrapped PyTorch model to which the slices will be applied. - window (int): The window identifier. - seed (str): The seed used for generating indices. - compression (int): The compression factor. - key (str): The key used to identify the slices. - - Returns: - List[str]: A list of all the slice files that were applied. - - Example: - slice_files = await apply_slices_to_model( - model=my_fsdp_model, - window=42, - seed="1234", - compression=10, - key='slice', - ) - - Notes: - - This function is adapted to work with FSDP. It ensures that all ranks participate - in collective operations required by FSDP to prevent hangs. - - Exception handling is added to ensure that any errors are caught, and all ranks exit gracefully. 
- """ rank = dist.get_rank() world_size = dist.get_world_size() - logger.debug(f"Rank {rank}: Starting apply_slices_to_model") + logger.debug(f"Rank {rank}: Starting apply_slices_to_model for window {window}") # Get indices associated with the window (all ranks must participate) try: @@ -164,16 +136,26 @@ async def apply_slices_to_model( if rank == 0: try: slice_files: List[str] = await load_files_for_window(window=window, key=key) - logger.debug(f"Rank {rank}: Loaded {len(slice_files)} slice files") + logger.debug(f"Rank {rank}: Loaded {len(slice_files)} slice files for window {window}") except Exception as e: logger.exception(f"Rank {rank}: Failed to load slice files: {e}") slice_files = [] else: - slice_files = [] + slice_files = None # Placeholder for other ranks + + # Broadcast slice_files from rank 0 to all ranks + try: + slice_files_list = [slice_files] if rank == 0 else [None] + dist.broadcast_object_list(slice_files_list, src=0) + slice_files = slice_files_list[0] + logger.debug(f"Rank {rank}: Received slice_files: {slice_files}") + except Exception as e: + logger.exception(f"Rank {rank}: Failed to broadcast slice files: {e}") + sys.exit(1) # Ensure all ranks exit if not slice_files: logger.warning(f"Rank {rank}: No slice files to process for window {window}") - return slice_files # Early return, but all ranks have synchronized here + return slice_files # All ranks return here synchronously # Initialize dictionaries to keep track of sums and counts param_sums: Dict[str, torch.Tensor] = {} @@ -209,10 +191,27 @@ async def apply_slices_to_model( else: logger.warning(f"Rank {rank}: No slices applied for parameter {name}") - # All ranks participate in updating the model parameters + # Broadcast param_sums and slices_per_param from rank 0 to all ranks + try: + if rank == 0: + data_to_broadcast = (param_sums, slices_per_param) + else: + data_to_broadcast = None + + data_list = [data_to_broadcast] + dist.broadcast_object_list(data_list, src=0) + if rank != 0: + param_sums, slices_per_param = data_list[0] + + logger.debug(f"Rank {rank}: Received param_sums and slices_per_param") + except Exception as e: + logger.exception(f"Rank {rank}: Failed to broadcast parameter sums: {e}") + sys.exit(1) # Ensure all ranks exit + + # All ranks proceed to update the model parameters try: # Retrieve the full state_dict (all ranks must participate) - cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False) + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): state_dict: Dict[str, torch.Tensor] = model.state_dict() @@ -231,7 +230,7 @@ async def apply_slices_to_model( logger.trace(f"Rank {rank}: No updates applied to parameter {name}") # Broadcast the updated state_dict from rank 0 to all other ranks - state_dict_list = [state_dict] + state_dict_list = [state_dict] if rank == 0 else [None] dist.broadcast_object_list(state_dict_list, src=0) state_dict = state_dict_list[0] logger.debug(f"Rank {rank}: Received updated state_dict from broadcast") From b4ddec9cbfb440b67d6b799f6448ca87153f1d22 Mon Sep 17 00:00:00 2001 From: distributedstatemachine Date: Tue, 29 Oct 2024 19:28:52 +0400 Subject: [PATCH 9/9] hparams --- hparams.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hparams.json b/hparams.json index 35735d9..3b16da1 100644 --- a/hparams.json +++ b/hparams.json @@ -3,11 +3,11 @@ "compression": 100, "sequence_length": 2048, "tokenizer_name": "togethercomputer/LLaMA-2-7B-32K", - 
"num_hidden_layers": 16, - "hidden_size": 2048, + "num_hidden_layers": 40, + "hidden_size": 5120, "intermediate_size": 8192, - "num_attention_heads": 8, - "num_key_value_heads": 8, + "num_attention_heads": 40, + "num_key_value_heads": 40, "activation_function": "swiGLU", "max_position_embeddings": 2048, "mixed_precision_param": "bfloat16",