From 8e3d9f3acf915f98eff0c64b6af10890b7e2a993 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 5 Nov 2024 13:09:34 -0800
Subject: [PATCH] Make the chat distributed

---
 llms/mlx_lm/chat.py | 53 +++++++++++++++++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 85d32d5fc..e60c24b16 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -14,6 +14,28 @@
 DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
+MAX_PROMPT_CHARS = 16384
+
+
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        prompt_array = mx.array(prompt.encode())
+        prompt_array = mx.concatenate(
+            [prompt_array, mx.zeros(MAX_PROMPT_CHARS - len(prompt_array), dtype=mx.uint8)]
+        )
+
+    else:
+        prompt_array = mx.zeros(MAX_PROMPT_CHARS, dtype=mx.uint8)
+
+    prompt_array = mx.distributed.all_sum(prompt_array)
+    mx.eval(prompt_array)
+    prompt = bytes(prompt_array)
+    idx = prompt.index(b"\x00" * 4)
+    return prompt[:idx].decode()
+
 
 def setup_arg_parser():
     """Set up and return the argument parser."""
@@ -53,6 +75,7 @@ def setup_arg_parser():
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
 
@@ -62,18 +85,27 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
-    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.", flush=True)
+    world.barrier()
 
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
+        prompt = None
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                prompt = query
+            else:
+                messages = [{"role": "user", "content": query}]
+                prompt = tokenizer.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
+                )
+        prompt = share_message(world, prompt)
+        if prompt == "q":
             break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
         for response in stream_generate(
             model,
             tokenizer,
@@ -83,9 +115,12 @@
             args.temp,
             top_p=args.top_p,
             prompt_cache=prompt_cache,
         ):
-            print(response, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response, flush=True, end="")
+        if world.rank() == 0:
+            print()
 
 if __name__ == "__main__":
     main()
+
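
Note (not part of the patch): share_message broadcasts rank 0's prompt to the
other ranks by zero-padding the encoded prompt to MAX_PROMPT_CHARS and taking
mx.distributed.all_sum over the buffers; the non-root ranks contribute all-zero
buffers, so the element-wise sum reproduces rank 0's bytes on every rank, and
the run of padding zeros marks where the message ends. Below is a minimal
standalone sketch of the same trick, reusing only calls that appear in the
patch; the file name, the MAX_CHARS value, and the mpirun launch line are
assumptions for illustration, not part of this change.

    # broadcast_demo.py -- sketch of the zero-padded all_sum broadcast used above.
    # Launching via MPI is an assumption, e.g.: mpirun -np 2 python broadcast_demo.py
    import mlx.core as mx

    MAX_CHARS = 64  # hypothetical small cap; the patch uses MAX_PROMPT_CHARS = 16384

    world = mx.distributed.init()

    if world.rank() == 0:
        # Rank 0 encodes its message and pads it with zeros up to MAX_CHARS.
        buf = mx.array("hello from rank 0".encode())
        buf = mx.concatenate([buf, mx.zeros(MAX_CHARS - len(buf), dtype=mx.uint8)])
    else:
        # Every other rank contributes an all-zero buffer of the same length.
        buf = mx.zeros(MAX_CHARS, dtype=mx.uint8)

    # Summing across ranks reproduces rank 0's buffer everywhere.
    buf = mx.distributed.all_sum(buf)
    mx.eval(buf)

    raw = bytes(buf)
    message = raw[: raw.index(b"\x00")].decode()  # truncate at the zero padding
    print(f"rank {world.rank()} received: {message}")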