diff --git a/.gitattributes b/.gitattributes
index ab26661da..6b5c882ea 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,2 +1,5 @@
 *.mp3 filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
+*.psd filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
diff --git a/exo/helpers.py b/exo/helpers.py
index ff0205f00..4833c9226 100644
--- a/exo/helpers.py
+++ b/exo/helpers.py
@@ -8,6 +8,7 @@ import psutil
 import uuid
 from scapy.all import get_if_addr, get_if_list
+from scapy.arch.windows import get_windows_if_list  # Windows-specific
 import re
 import subprocess
 from pathlib import Path
@@ -15,6 +16,8 @@ import json
 from concurrent.futures import ThreadPoolExecutor
 import traceback
+import netifaces
+from typing import List, Tuple
 
 DEBUG = int(os.getenv("DEBUG", default="0"))
 DEBUG_DISCOVERY = int(os.getenv("DEBUG_DISCOVERY", default="0"))
@@ -230,22 +233,67 @@ def pretty_print_bytes_per_second(bytes_per_second: int) -> str:
   return f"{bytes_per_second / (1024 ** 4):.2f} TB/s"
 
 
-def get_all_ip_addresses_and_interfaces():
+def get_all_ip_addresses_and_interfaces() -> List[Tuple[str, str]]:
+  """
+  Get all active IP addresses and their corresponding interfaces.
+  Excludes loopback, non-routable, and inactive interfaces.
+  """
   ip_addresses = []
-  for interface in get_if_list():
-    try:
-      ip = get_if_addr(interface)
-      if ip.startswith("0.0."): continue
-      simplified_interface = re.sub(r'^\\Device\\NPF_', '', interface)
-      ip_addresses.append((ip, simplified_interface))
-    except:
-      if DEBUG >= 1: print(f"Failed to get IP address for interface {interface}")
-      if DEBUG >= 1: traceback.print_exc()
+
+  # Exclude these IP prefixes (non-routable or problematic)
+  excluded_prefixes = (
+    "127.",  # Loopback
+    "169.254.",  # Link-local
+    "192.168.137.",  # Hyper-V, VMware, etc.
+    "0.0.0.0",  # Invalid
+  )
+
+  # Get all network interfaces
+  interfaces = netifaces.interfaces()
+
+  for interface in interfaces:
+    try:
+      # Get IPv4 addresses for the interface
+      addrs = netifaces.ifaddresses(interface).get(netifaces.AF_INET, [])
+      for addr in addrs:
+        ip = addr.get("addr")
+        if not ip:
+          continue
+
+        # Skip excluded IPs
+        if any(ip.startswith(prefix) for prefix in excluded_prefixes):
+          if DEBUG >= 1:
+            print(f"Skipping excluded IP: {ip} on interface {interface}")
+          continue
+
+        # On Windows, check if the interface is active
+        if psutil.WINDOWS:
+          try:
+            # Use psutil to check interface status
+            if_stats = psutil.net_if_stats().get(interface)
+            if if_stats and not if_stats.isup:
+              if DEBUG >= 1:
+                print(f"Skipping inactive interface: {interface}")
+              continue
+          except Exception as e:
+            if DEBUG >= 1:
+              print(f"Error checking interface status for {interface}: {e}")
+
+        # Add the IP and interface to the list
+        ip_addresses.append((ip, interface))
+
+    except Exception as e:
+      if DEBUG >= 1:
+        print(f"Error processing interface {interface}: {e}")
+        traceback.print_exc()
+
+  # If no valid IPs are found, default to localhost
   if not ip_addresses:
-    if DEBUG >= 1: print("Failed to get any IP addresses. Defaulting to localhost.")
-    return [("localhost", "lo")]
-  return list(set(ip_addresses))
+    if DEBUG >= 1:
+      print("No valid IP addresses found. Defaulting to localhost.")
+    return [("127.0.0.1", "lo")]
+  return list(set(ip_addresses))  # Remove duplicates
 
 
 async def get_macos_interface_type(ifname: str) -> Optional[Tuple[int, str]]:
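The rewritten `get_all_ip_addresses_and_interfaces` above swaps scapy's `get_if_list`/`get_if_addr` for netifaces plus a psutil link-status check. Below is a minimal standalone sketch of that enumeration approach for readers who want to try it outside exo. It is illustrative only and not part of the patch: the helper name `candidate_addresses` is invented here, the prefix list simply mirrors the one in the diff, and unlike the patched function this sketch applies the link-status check on every platform, not just Windows.

```python
# Standalone sketch of the interface-enumeration approach used in the patch:
# netifaces supplies the IPv4 addresses, psutil reports whether the link is up.
import netifaces
import psutil

EXCLUDED_PREFIXES = ("127.", "169.254.", "192.168.137.", "0.0.0.0")

def candidate_addresses():
  results = []
  stats = psutil.net_if_stats()  # {interface name: snicstats(isup=..., ...)}
  for interface in netifaces.interfaces():
    for addr in netifaces.ifaddresses(interface).get(netifaces.AF_INET, []):
      ip = addr.get("addr")
      if not ip or ip.startswith(EXCLUDED_PREFIXES):
        continue  # skip loopback, link-local, ICS/virtual-switch and invalid addresses
      if_stats = stats.get(interface)
      if if_stats is not None and not if_stats.isup:
        continue  # skip interfaces that are administratively down
      results.append((ip, interface))
  return sorted(set(results))

if __name__ == "__main__":
  for ip, interface in candidate_addresses():
    print(f"{interface}: {ip}")
```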
diff --git a/exo/main.py b/exo/main.py
index db91251c8..fc4c47ce3 100644
--- a/exo/main.py
+++ b/exo/main.py
@@ -28,34 +28,47 @@ from exo.inference.tokenizers import resolve_tokenizer
 from exo.models import build_base_shard, get_repo
 from exo.viz.topology_viz import TopologyViz
-import uvloop
 import concurrent.futures
-import resource
 import psutil
 
-# TODO: figure out why this is happening
+# Environment variable configuration
 os.environ["GRPC_VERBOSITY"] = "error"
 os.environ["TRANSFORMERS_VERBOSITY"] = "error"
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
-# Configure uvloop for maximum performance
-def configure_uvloop():
-  uvloop.install()
-  loop = asyncio.new_event_loop()
-  asyncio.set_event_loop(loop)
+# Configure event loop based on platform
+def configure_event_loop():
+    if platform.system() == "Windows":
+        # Use the default ProactorEventLoop on Windows
+        loop = asyncio.ProactorEventLoop()
+        asyncio.set_event_loop(loop)
+    else:
+        # Use uvloop on non-Windows platforms
+        import uvloop
+        uvloop.install()
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
 
-  # Increase file descriptor limits on Unix systems
+    # Set file descriptor limits on Unix systems
   if not psutil.WINDOWS:
-    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-    try: resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
-    except ValueError:
-      try: resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
-      except ValueError: pass
-
-  loop.set_default_executor(concurrent.futures.ThreadPoolExecutor(max_workers=min(32, (os.cpu_count() or 1) * 4)))
+        import resource  # local import: the unconditional top-level "import resource" was removed because the module is unavailable on Windows
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        try:
+            resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
+        except ValueError:
+            try:
+                resource.setrlimit(resource.RLIMIT_NOFILE, (8192, hard))
+            except ValueError:
+                pass
+
+    # Configure thread pool executor
+    loop.set_default_executor(
+        concurrent.futures.ThreadPoolExecutor(
+            max_workers=min(32, (os.cpu_count() or 1) * 4)
+        )
+    )
     return loop
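The patch makes event-loop setup conditional: the stock proactor loop on Windows, uvloop elsewhere. A self-contained sketch of the same pattern is shown below; it is not exo code, the function name `make_event_loop` is invented, and it assumes Python 3.8+ (where the proactor loop is already the Windows default) and that uvloop is installed on non-Windows machines only.

```python
# Minimal sketch of platform-conditional event-loop selection (illustrative only).
import asyncio
import platform

def make_event_loop() -> asyncio.AbstractEventLoop:
  if platform.system() == "Windows":
    # ProactorEventLoop supports subprocesses and is the default on Python 3.8+.
    loop = asyncio.ProactorEventLoop()
  else:
    import uvloop  # third-party accelerator, POSIX-only
    uvloop.install()
    loop = asyncio.new_event_loop()
  asyncio.set_event_loop(loop)
  return loop

if __name__ == "__main__":
  loop = make_event_loop()
  loop.run_until_complete(asyncio.sleep(0))
  loop.close()
```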
[f"http://{ip}:{args.chatgpt_api_port}" for ip, _ in get_all_ip_addresses_and_interfaces()] + if DEBUG >= 0: - print("Chat interface started:") - for web_chat_url in web_chat_urls: - print(f" - {terminal_link(web_chat_url)}") - print("ChatGPT API endpoint served at:") - for chatgpt_api_endpoint in chatgpt_api_endpoints: - print(f" - {terminal_link(chatgpt_api_endpoint)}") + print("Chat interface started:") + for web_chat_url in web_chat_urls: + print(f" - {terminal_link(web_chat_url)}") + print("ChatGPT API endpoint served at:") + for chatgpt_api_endpoint in chatgpt_api_endpoints: + print(f" - {terminal_link(chatgpt_api_endpoint)}") # Convert node-id-filter and interface-type-filter to lists if provided allowed_node_ids = args.node_id_filter.split(',') if args.node_id_filter else None allowed_interface_types = args.interface_type_filter.split(',') if args.interface_type_filter else None +# Initialize discovery based on selected module if args.discovery_module == "udp": - discovery = UDPDiscovery( - args.node_id, - args.node_port, - args.listen_port, - args.broadcast_port, - lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), - discovery_timeout=args.discovery_timeout, - allowed_node_ids=allowed_node_ids, - allowed_interface_types=allowed_interface_types - ) + discovery = UDPDiscovery( + args.node_id, + args.node_port, + args.listen_port, + args.broadcast_port, + lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), + discovery_timeout=args.discovery_timeout, + allowed_node_ids=allowed_node_ids, + allowed_interface_types=allowed_interface_types + ) elif args.discovery_module == "tailscale": - discovery = TailscaleDiscovery( - args.node_id, - args.node_port, - lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), - discovery_timeout=args.discovery_timeout, - tailscale_api_key=args.tailscale_api_key, - tailnet=args.tailnet_name, - allowed_node_ids=allowed_node_ids - ) + discovery = TailscaleDiscovery( + args.node_id, + args.node_port, + lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities), + discovery_timeout=args.discovery_timeout, + tailscale_api_key=args.tailscale_api_key, + tailnet=args.tailnet_name, + allowed_node_ids=allowed_node_ids + ) elif args.discovery_module == "manual": - if not args.discovery_config_path: - raise ValueError(f"--discovery-config-path is required when using manual discovery. Please provide a path to a config json file.") - discovery = ManualDiscovery(args.discovery_config_path, args.node_id, create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)) + if not args.discovery_config_path: + raise ValueError("--discovery-config-path is required when using manual discovery. 
+    discovery = ManualDiscovery(
+        args.discovery_config_path,
+        args.node_id,
+        create_peer_handle=lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities)
+    )
+
 topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
+
+# Initialize node and server
 node = Node(
-  args.node_id,
-  None,
-  inference_engine,
-  discovery,
-  shard_downloader,
-  partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
-  max_generate_tokens=args.max_generate_tokens,
-  topology_viz=topology_viz,
-  default_sample_temperature=args.default_temp
+    args.node_id,
+    None,
+    inference_engine,
+    discovery,
+    shard_downloader,
+    partitioning_strategy=RingMemoryWeightedPartitioningStrategy(),
+    max_generate_tokens=args.max_generate_tokens,
+    topology_viz=topology_viz,
+    default_sample_temperature=args.default_temp
 )
+
 server = GRPCServer(node, args.node_host, args.node_port)
 node.server = server
+
+# Initialize ChatGPT API
 api = ChatGPTAPI(
-  node,
-  node.inference_engine.__class__.__name__,
-  response_timeout=args.chatgpt_api_response_timeout,
-  on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
-  default_model=args.default_model,
-  system_prompt=args.system_prompt
+    node,
+    node.inference_engine.__class__.__name__,
+    response_timeout=args.chatgpt_api_response_timeout,
+    on_chat_completion_request=lambda req_id, __, prompt: topology_viz.update_prompt(req_id, prompt) if topology_viz else None,
+    default_model=args.default_model,
+    system_prompt=args.system_prompt
 )
+
+# Token output buffering
 buffered_token_output = {}
+
 def update_topology_viz(req_id, tokens, __):
-  if not topology_viz: return
-  if not node.inference_engine.shard: return
-  if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base': return
-  if req_id in buffered_token_output: buffered_token_output[req_id].extend(tokens)
-  else: buffered_token_output[req_id] = tokens
-  topology_viz.update_prompt_output(req_id, node.inference_engine.tokenizer.decode(buffered_token_output[req_id]))
+    if not topology_viz:
+        return
+    if not node.inference_engine.shard:
+        return
+    if node.inference_engine.shard.model_id == 'stable-diffusion-2-1-base':
+        return
+    if req_id in buffered_token_output:
+        buffered_token_output[req_id].extend(tokens)
+    else:
+        buffered_token_output[req_id] = tokens
+    topology_viz.update_prompt_output(req_id, node.inference_engine.tokenizer.decode(buffered_token_output[req_id]))
+
 node.on_token.register("update_topology_viz").on_next(update_topology_viz)
+
 def update_prompt_viz(request_id, opaque_status: str):
-  if not topology_viz: return
-  try:
-    status = json.loads(opaque_status)
-    if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
-    topology_viz.update_prompt(request_id, status.get("prompt", "corrupted prompt (this should never happen)"))
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Failed to update prompt viz: {e}")
-      traceback.print_exc()
+    if not topology_viz:
+        return
+    try:
+        status = json.loads(opaque_status)
+        if status.get("type") != "node_status" or status.get("status") != "start_process_prompt":
+            return
+        topology_viz.update_prompt(request_id, status.get("prompt", "corrupted prompt (this should never happen)"))
+    except Exception as e:
+        if DEBUG >= 2:
+            print(f"Failed to update prompt viz: {e}")
+            traceback.print_exc()
+
 node.on_opaque_status.register("update_prompt_viz").on_next(update_prompt_viz)
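Both `update_prompt_viz` above and `preemptively_load_shard` below key off the same opaque-status payload. The sketch below reconstructs the shape they expect purely from the guards visible in this diff; it is illustrative, not a schema definition, and the example prompt and the empty `shard` dict are placeholders (the real `Shard.from_dict` structure is not shown here).

```python
# Illustrative only: the opaque-status payload as implied by the guards above.
import json

status_message = json.dumps({
  "type": "node_status",              # anything else is ignored by both callbacks
  "status": "start_process_prompt",   # anything else is ignored by both callbacks
  "prompt": "What is the capital of France?",
  "shard": {},                        # fed to Shard.from_dict(...); structure not shown in this diff
})

status = json.loads(status_message)
if status.get("type") == "node_status" and status.get("status") == "start_process_prompt":
  print(status.get("prompt", "corrupted prompt (this should never happen)"))
```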
 
 def preemptively_load_shard(request_id: str, opaque_status: str):
-  try:
-    status = json.loads(opaque_status)
-    if status.get("type") != "node_status" or status.get("status") != "start_process_prompt": return
-    current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
-    if DEBUG >= 2: print(f"Preemptively starting download for {current_shard}")
-    asyncio.create_task(node.inference_engine.ensure_shard(current_shard))
-  except Exception as e:
-    if DEBUG >= 2:
-      print(f"Failed to preemptively start download: {e}")
-      traceback.print_exc()
+    try:
+        status = json.loads(opaque_status)
+        if status.get("type") != "node_status" or status.get("status") != "start_process_prompt":
+            return
+        current_shard = node.get_current_shard(Shard.from_dict(status.get("shard")))
+        if DEBUG >= 2:
+            print(f"Preemptively starting download for {current_shard}")
+        asyncio.create_task(node.inference_engine.ensure_shard(current_shard))
+    except Exception as e:
+        if DEBUG >= 2:
+            print(f"Failed to preemptively start download: {e}")
+            traceback.print_exc()
+
 node.on_opaque_status.register("preemptively_load_shard").on_next(preemptively_load_shard)
 
+# Progress event handling
 last_events: dict[str, tuple[float, RepoProgressEvent]] = {}
+
 def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
-  global last_events
-  current_time = time.time()
-  if event.status == "not_started": return
-  last_event = last_events.get(shard.model_id)
-  if last_event and last_event[1].status == "complete" and event.status == "complete": return
-  if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
-  last_events[shard.model_id] = (current_time, event)
-  asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
+    global last_events
+    current_time = time.time()
+    if event.status == "not_started":
+        return
+    last_event = last_events.get(shard.model_id)
+    if last_event and last_event[1].status == "complete" and event.status == "complete":
+        return
+    if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2:
+        return
+    last_events[shard.model_id] = (current_time, event)
+    asyncio.create_task(node.broadcast_opaque_status("", json.dumps({
+        "type": "download_progress",
+        "node_id": node.id,
+        "progress": event.to_dict()
+    })))
+
 shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
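`throttled_broadcast` rate-limits download-progress broadcasts: it drops `not_started` events, suppresses repeated `complete` events, and limits same-status updates for a model to roughly one every 0.2 s before serializing and broadcasting through the node. Below is a self-contained sketch of just that throttling pattern, with the broadcast step left out; the names `should_broadcast` and `_last_events` are invented for the example.

```python
# Standalone sketch of per-key, time-based event throttling (illustrative only).
import time

_last_events: dict[str, tuple[float, str]] = {}  # key -> (timestamp, last status)

def should_broadcast(key: str, status: str, min_interval: float = 0.2) -> bool:
  now = time.time()
  if status == "not_started":
    return False
  last = _last_events.get(key)
  if last is not None:
    last_time, last_status = last
    if last_status == "complete" and status == "complete":
      return False  # never re-announce completion
    if last_status == status and now - last_time < min_interval:
      return False  # same status arriving too soon
  _last_events[key] = (now, status)
  return True

assert should_broadcast("model-a", "in_progress")
assert not should_broadcast("model-a", "in_progress")  # same status within 0.2 s
```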
 
 
 async def run_model_cli(node: Node, model_name: str, prompt: str):
-  inference_class = node.inference_engine.__class__.__name__
-  shard = build_base_shard(model_name, inference_class)
-  if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
-    return
-  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  request_id = str(uuid.uuid4())
-  callback_id = f"cli-wait-response-{request_id}"
-  callback = node.on_token.register(callback_id)
-  if topology_viz:
-    topology_viz.update_prompt(request_id, prompt)
-  prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
-
-  try:
-    print(f"Processing prompt: {prompt}")
-    await node.process_prompt(shard, prompt, request_id=request_id)
+    inference_class = node.inference_engine.__class__.__name__
+    shard = build_base_shard(model_name, inference_class)
+    if not shard:
+        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
+        return
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
+    request_id = str(uuid.uuid4())
+    callback_id = f"cli-wait-response-{request_id}"
+    callback = node.on_token.register(callback_id)
+    if topology_viz:
+        topology_viz.update_prompt(request_id, prompt)
+    prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True)
 
-    tokens = []
-    def on_token(_request_id, _tokens, _is_finished):
-      tokens.extend(_tokens)
-      return _request_id == request_id and _is_finished
-    await callback.wait(on_token, timeout=300)
-
-    print("\nGenerated response:")
-    print(tokenizer.decode(tokens))
-  except Exception as e:
-    print(f"Error processing prompt: {str(e)}")
-    traceback.print_exc()
-  finally:
-    node.on_token.deregister(callback_id)
+    try:
+        print(f"Processing prompt: {prompt}")
+        await node.process_prompt(shard, prompt, request_id=request_id)
 
-def clean_path(path):
-  """Clean and resolve path"""
-  if path.startswith("Optional("):
-    path = path.strip('Optional("').rstrip('")')
-  return os.path.expanduser(path)
+        tokens = []
+        def on_token(_request_id, _tokens, _is_finished):
+            tokens.extend(_tokens)
+            return _request_id == request_id and _is_finished
+        await callback.wait(on_token, timeout=300)
+
+        print("\nGenerated response:")
+        print(tokenizer.decode(tokens))
+    except Exception as e:
+        print(f"Error processing prompt: {str(e)}")
+        traceback.print_exc()
+    finally:
+        node.on_token.deregister(callback_id)
 
 
 async def hold_outstanding(node: Node):
-  while node.outstanding_requests:
-    await asyncio.sleep(.5)
-  return
+    while node.outstanding_requests:
+        await asyncio.sleep(.5)
+    return
 
 
 async def run_iter(node: Node, shard: Shard, train: bool, data, batch_size=1):
-  losses = []
-  tokens = []
-  for batch in tqdm(iterate_batches(data, batch_size), total=len(data) // batch_size):
-    _, _, lengths = batch
-    losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=train)))
-    tokens.append(np.sum(lengths))
-  total_tokens = np.sum(tokens)
-  total_loss = np.sum(losses) / total_tokens
+    losses = []
+    tokens = []
+    for batch in tqdm(iterate_batches(data, batch_size), total=len(data) // batch_size):
+        _, _, lengths = batch
+        losses.append(np.sum(lengths * await node.enqueue_example(shard, *batch, train=train)))
+        tokens.append(np.sum(lengths))
+    total_tokens = np.sum(tokens)
+    total_loss = np.sum(losses) / total_tokens
 
-  return total_loss, total_tokens
+    return total_loss, total_tokens
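In `run_iter`, each batch contributes `np.sum(lengths * per_example_loss)` to `losses` and `np.sum(lengths)` to `tokens`, so the reported `total_loss` is a token-weighted average rather than a plain mean over batches. A small numeric illustration of that aggregation follows; the loss and length values are invented for the example.

```python
# Token-weighted loss aggregation as in run_iter (illustrative numbers only).
import numpy as np

# Two batches: per-example losses and per-example token lengths.
batch_losses = [np.array([2.0, 1.0]), np.array([4.0])]
batch_lengths = [np.array([10, 30]), np.array([60])]

losses = [np.sum(lengths * loss) for loss, lengths in zip(batch_losses, batch_lengths)]
tokens = [np.sum(lengths) for lengths in batch_lengths]

total_tokens = np.sum(tokens)               # 100
total_loss = np.sum(losses) / total_tokens  # (20 + 30 + 240) / 100 = 2.9
print(total_loss, total_tokens)
```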
 
 
 async def eval_model_cli(node: Node, model_name, dataloader, batch_size, num_batches=-1):
-  inference_class = node.inference_engine.__class__.__name__
-  shard = build_base_shard(model_name, inference_class)
-  if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
-    return
-  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  train, val, test = dataloader(tokenizer.encode)
-  print(f"Evaluating {len(test)} examples with batch_size {batch_size}")
-  loss, tokens = await run_iter(node, shard, False, test, batch_size)
-  print(f"total | {loss=}, {tokens=}")
-  print("Waiting for outstanding tasks")
-  await hold_outstanding(node)
+    inference_class = node.inference_engine.__class__.__name__
+    shard = build_base_shard(model_name, inference_class)
+    if not shard:
+        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
+        return
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
+    train, val, test = dataloader(tokenizer.encode)
+    print(f"Evaluating {len(test)} examples with batch_size {batch_size}")
+    loss, tokens = await run_iter(node, shard, False, test, batch_size)
+    print(f"total | {loss=}, {tokens=}")
+    print("Waiting for outstanding tasks")
+    await hold_outstanding(node)
 
 
 async def train_model_cli(node: Node, model_name, dataloader, batch_size, iters, save_interval=0, checkpoint_dir=None):
-  inference_class = node.inference_engine.__class__.__name__
-  shard = build_base_shard(model_name, inference_class)
-  if not shard:
-    print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
-    return
-  tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
-  train, val, test = dataloader(tokenizer.encode)
-  print(f"Training on {len(train)} examples with batch_size {batch_size} for {iters} epochs")
-  for i in tqdm(range(3)):
-    await asyncio.sleep(1)
-  for epoch in range(iters):
-    loss, tokens = await run_iter(node, shard, True, train, batch_size)
-    print(f"epoch {epoch + 1}/{iters}\t| loss: {loss}, tokens: {tokens}")
-    if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0 and checkpoint_dir is not None:
-      await node.coordinate_save(shard, epoch, checkpoint_dir)
-      await hold_outstanding(node)
-  await hold_outstanding(node)
+    inference_class = node.inference_engine.__class__.__name__
+    shard = build_base_shard(model_name, inference_class)
+    if not shard:
+        print(f"Error: Unsupported model '{model_name}' for inference engine {inference_class}")
+        return
+    tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
+    train, val, test = dataloader(tokenizer.encode)
+    print(f"Training on {len(train)} examples with batch_size {batch_size} for {iters} epochs")
+    for i in tqdm(range(3)):
+        await asyncio.sleep(1)
+    for epoch in range(iters):
+        loss, tokens = await run_iter(node, shard, True, train, batch_size)
+        print(f"epoch {epoch + 1}/{iters}\t| loss: {loss}, tokens: {tokens}")
+        if save_interval > 0 and epoch > 0 and (epoch % save_interval) == 0 and checkpoint_dir is not None:
+            await node.coordinate_save(shard, epoch, checkpoint_dir)
+            await hold_outstanding(node)
+    await hold_outstanding(node)
+
+
+def clean_path(path):
+    """Clean and resolve path"""
+    if path.startswith("Optional("):
+        path = path.strip('Optional("').rstrip('")')
+    return os.path.expanduser(path)
 
 
 async def check_exo_home():
-  home, has_read, has_write = await ensure_exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
-  if DEBUG >= 1: print(f"exo home directory: {home}")
-  print(f"{has_read=}, {has_write=}")
-  if not has_read or not has_write:
-    print(f"""
-          WARNING: Limited permissions for exo home directory: {home}.
-          This may prevent model downloads from working correctly.
-          {"❌ No read access" if not has_read else ""}
-          {"❌ No write access" if not has_write else ""}
-          """)
+    home, has_read, has_write = await ensure_exo_home(), await has_exo_home_read_access(), await has_exo_home_write_access()
+    if DEBUG >= 1:
+        print(f"exo home directory: {home}")
+        print(f"{has_read=}, {has_write=}")
+    if not has_read or not has_write:
+        print(f"""
+        WARNING: Limited permissions for exo home directory: {home}.
+        This may prevent model downloads from working correctly.
+        {"❌ No read access" if not has_read else ""}
+        {"❌ No write access" if not has_write else ""}
+        """)
 
 
 async def main():
-  loop = asyncio.get_running_loop()
-
-  try: await check_exo_home()
-  except Exception as e: print(f"Error checking exo home directory: {e}")
+    loop = asyncio.get_running_loop()
 
-  if not args.models_seed_dir is None:
     try:
-      models_seed_dir = clean_path(args.models_seed_dir)
-      await seed_models(models_seed_dir)
+        await check_exo_home()
     except Exception as e:
-      print(f"Error seeding models: {e}")
+        print(f"Error checking exo home directory: {e}")
+
+    if not args.models_seed_dir is None:
+        try:
+            models_seed_dir = clean_path(args.models_seed_dir)
+            await seed_models(models_seed_dir)
+        except Exception as e:
+            print(f"Error seeding models: {e}")
+
+    def restore_cursor():
+        if platform.system() != "Windows":
+            os.system("tput cnorm")  # Show cursor
+
+    # Restore the cursor when the program exits
+    atexit.register(restore_cursor)
+
+    # Use a more direct approach to handle signals
+    def handle_exit():
+        asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
 
-  def restore_cursor():
     if platform.system() != "Windows":
-      os.system("tput cnorm")  # Show cursor
-
-  # Restore the cursor when the program exits
-  atexit.register(restore_cursor)
-
-  # Use a more direct approach to handle signals
-  def handle_exit():
-    asyncio.ensure_future(shutdown(signal.SIGTERM, loop, node.server))
-
-  if platform.system() != "Windows":
-    for s in [signal.SIGINT, signal.SIGTERM]:
-      loop.add_signal_handler(s, handle_exit)
-
-  await node.start(wait_for_peers=args.wait_for_peers)
-
-  if args.command == "run" or args.run_model:
-    model_name = args.model_name or args.run_model
-    if not model_name:
-      print("Error: Model name is required when using 'run' command or --run-model")
-      return
-    await run_model_cli(node, model_name, args.prompt)
-  elif args.command == "eval" or args.command == 'train':
-    model_name = args.model_name
-    dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item)
-                                          , loadline=lambda line: json.loads(line).get("text",""))
-    if args.command == 'eval':
-      if not model_name:
-        print("Error: Much like a human, I can't evaluate anything without a model")
-        return
-      await eval_model_cli(node, model_name, dataloader, args.batch_size)
+        for s in [signal.SIGINT, signal.SIGTERM]:
+            loop.add_signal_handler(s, handle_exit)
+
+    await node.start(wait_for_peers=args.wait_for_peers)
+
+    if args.command == "run" or args.run_model:
+        model_name = args.model_name or args.run_model
+        if not model_name:
+            print("Error: Model name is required when using 'run' command or --run-model")
+            return
+        await run_model_cli(node, model_name, args.prompt)
+    elif args.command == "eval" or args.command == 'train':
+        model_name = args.model_name
+        dataloader = lambda tok: load_dataset(args.data, preprocess=lambda item: tok(item),
+                                              loadline=lambda line: json.loads(line).get("text", ""))
+        if args.command == 'eval':
+            if not model_name:
+                print("Error: Much like a human, I can't evaluate anything without a model")
+                return
+            await eval_model_cli(node, model_name, dataloader, args.batch_size)
+        else:
+            if not model_name:
+                print("Error: This train ain't leaving the station without a model")
+                return
+            await train_model_cli(node, model_name, dataloader, args.batch_size, args.iters,
+                                  save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
     else:
-      if not model_name:
-        print("Error: This train ain't leaving the station without a model")
-        return
-      await train_model_cli(node, model_name, dataloader, args.batch_size, args.iters, save_interval=args.save_every, checkpoint_dir=args.save_checkpoint_dir)
-
-  else:
-    asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
-    await asyncio.Event().wait()
+        asyncio.create_task(api.run(port=args.chatgpt_api_port))  # Start the API server as a non-blocking task
+        await asyncio.Event().wait()
 
-  if args.wait_for_peers > 0:
-    print("Cooldown to allow peers to exit gracefully")
-    for i in tqdm(range(50)):
-      await asyncio.sleep(.1)
+    if args.wait_for_peers > 0:
+        print("Cooldown to allow peers to exit gracefully")
+        for i in tqdm(range(50)):
+            await asyncio.sleep(.1)
 
 
 def run():
   loop = None
   try:
-    loop = configure_uvloop()
+    loop = configure_event_loop()
     loop.run_until_complete(main())
   except KeyboardInterrupt:
     print("\nShutdown requested... exiting")
   finally:
-    if loop: loop.close()
+    if loop:
+      loop.close()
 
 
 if __name__ == "__main__":
-  run()
+  run()
\ No newline at end of file
diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py
index ae28e757f..b4a5baeb3 100644
--- a/exo/topology/device_capabilities.py
+++ b/exo/topology/device_capabilities.py
@@ -152,7 +152,7 @@ async def device_capabilities() -> DeviceCapabilities:
   elif psutil.LINUX:
     return await linux_device_capabilities()
   elif psutil.WINDOWS:
-    return await windows_device_capabilities()
+    return windows_device_capabilities()
   else:
     return DeviceCapabilities(
       model="Unknown Device",
diff --git a/setup.py b/setup.py
index de242f544..64e3de21b 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,6 @@
   "tqdm==4.66.4",
   "transformers==4.46.3",
   "uuid==1.30",
-  "uvloop==0.21.0",
   "tinygrad @ git+https://github.com/tinygrad/tinygrad.git@ec120ce6b9ce8e4ff4b5692566a683ef240e8bc8",
 ]
@@ -50,7 +49,6 @@
 if sys.platform.startswith("win32"):
   install_requires.extend(extras_require["windows"])
 
-
 def _add_gpu_requires():
   global install_requires
   # Add Nvidia-GPU
@@ -74,7 +72,6 @@ def _add_gpu_requires():
   finally:
     pass
 
-
 _add_gpu_requires()
 
 setup(