diff --git a/examples/inference/distributed/llava_next_video.py b/examples/inference/distributed/llava_next_video.py
index 4c18e65843b..50dc81face6 100644
--- a/examples/inference/distributed/llava_next_video.py
+++ b/examples/inference/distributed/llava_next_video.py
@@ -22,7 +22,9 @@
 import av
 from huggingface_hub import hf_hub_download
 import json
-from accelerate.utils import gather_object
+import queue
+from concurrent.futures import ThreadPoolExecutor
+import pathlib
 
 START_TIME = time.strftime("%Y%m%d_%H%M%S")
 DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
@@ -31,7 +33,7 @@
 """
 Example:
 
-accelerate launch --num_processes=2 llava_next_video.py
+accelerate launch llava_next_video.py
 """
 
 
@@ -39,6 +41,7 @@ def main(
     model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
     save_dir: str = "./evaluation/examples",
     dtype: str = "fp16",
+    num_workers: int = 1,
     low_mem: bool = True,
 ):
     # Start up the distributed environment without needing the Accelerator.
@@ -46,7 +49,7 @@ def main(
 
     processor = LlavaNextVideoProcessor.from_pretrained(model_name)
     model = LlavaNextVideoForConditionalGeneration.from_pretrained(
-        model_name, torch_dtype=dtype[DTYPE_MAP], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
+        model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
     )
 
     if distributed_state.is_main_process:
@@ -109,33 +112,43 @@ def main(
         for i in range(0, len(conversations))
     ]
 
-    count = 0
-    distributed_state.num_processes = len(formatted_prompts)
-    with distributed_state.split_between_processes(formatted_prompts) as prompts:
-        input = processor(text=prompts, videos=video, return_tensors="pt").to(model.device)
-        output = model.generate(**input, max_new_tokens=60)
-        generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
-
-        distributed_state.wait_for_everyone()
-
-        answers = gather_object(generated_text)
-        input_prompts = gather_object(prompts)
-
-        if distributed_state.is_main_process:
-            for ans, prompt in zip(answers, input_prompts):
-                count += 1
-                example_file = f"example_{count}"
-                temp_dir = os.path.join(save_dir, example_file)
-
-                metadata = {
-                    "prompt": prompt,
-                    "generated_answer": ans,
-                }
-                with open(temp_dir, "w") as f:
-                    json.dump(metadata, f, indent=4)
+    def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
+        count = 0
+        while True:
+            try:
+                item = output_queue.get(timeout=5)
+                if item is None:
+                    break
+                example_file = f"example_{count}"
+                temp_dir = os.path.join(output_dir, example_file)
+
+                metadata = {
+                    "prompt": item[0],
+                    "generated_answer": item[1],
+                }
+                with open(temp_dir, "w") as f:
+                    json.dump(metadata, f, indent=4)
+                count += 1
+
+            except queue.Empty:
+                continue
 
-    if distributed_state.is_main_process:
-        print(f">>> Video answer generation Finished. Saved in {save_dir}")
+    distributed_state.num_processes = len(formatted_prompts)
+    output_queue = queue.Queue()
+    save_thread = ThreadPoolExecutor(max_workers=num_workers)
+    save_future = save_thread.submit(save_results, output_queue, save_dir)
+
+    try:
+        with distributed_state.split_between_processes(formatted_prompts) as prompt:
+            input = processor(text=prompt, videos=video, padding=True, return_tensors="pt").to(model.device)
+            output = model.generate(**input, max_new_tokens=60)
+            generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
+            output_queue.put((prompt, generated_text))
+    finally:
+        output_queue.put(None)
+        save_thread.shutdown(wait=True)
+
+    save_future.result()
 
 
 def read_video_pyav(container, indices):
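For reference, the producer/consumer pattern this patch introduces can be exercised on its own, without the model or the distributed state. The sketch below is illustrative and not part of the patch: the hard-coded prompts and canned answers stand in for `model.generate`, while the queue, the `None` sentinel, and the `ThreadPoolExecutor` writer mirror the structure added above.

# Standalone sketch of the queue-based saving pattern (illustrative names,
# not part of the patch; the canned answers stand in for model.generate).
import json
import os
import queue
from concurrent.futures import ThreadPoolExecutor


def save_results(output_queue: queue.Queue, output_dir: str):
    # Drain the queue until the None sentinel arrives; the timeout keeps the
    # worker polling instead of blocking forever if the producer stalls.
    count = 0
    while True:
        try:
            item = output_queue.get(timeout=5)
        except queue.Empty:
            continue
        if item is None:
            break
        prompt, answer = item
        with open(os.path.join(output_dir, f"example_{count}"), "w") as f:
            json.dump({"prompt": prompt, "generated_answer": answer}, f, indent=4)
        count += 1


def main(save_dir: str = "./evaluation/examples", num_workers: int = 1):
    os.makedirs(save_dir, exist_ok=True)
    output_queue = queue.Queue()
    save_thread = ThreadPoolExecutor(max_workers=num_workers)
    save_future = save_thread.submit(save_results, output_queue, save_dir)
    try:
        # Stand-in for the generation loop: each result is handed off to the
        # writer thread instead of being gathered onto the main process.
        for prompt in ["USER: <video> Why is this funny?", "USER: <video> Describe the clip."]:
            output_queue.put((prompt, f"generated answer for: {prompt!r}"))
    finally:
        output_queue.put(None)  # sentinel: tell the writer to exit
        save_thread.shutdown(wait=True)
    save_future.result()  # surface any exception raised inside the writer


if __name__ == "__main__":
    main()

Writing results from a background thread as they are produced keeps the generation loop free of file I/O; the sentinel plus `save_future.result()` ensures the writer finishes (and its errors propagate) before the program exits.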