some fixes and refactoring
VladOS95-cyber committed Oct 20, 2024
1 parent 81905e1 commit c9bb16e
Showing 1 changed file with 42 additions and 29 deletions.
71 changes: 42 additions & 29 deletions examples/inference/distributed/llava_next_video.py
@@ -22,7 +22,9 @@
import av
from huggingface_hub import hf_hub_download
import json
from accelerate.utils import gather_object
import queue
from concurrent.futures import ThreadPoolExecutor
import pathlib

START_TIME = time.strftime("%Y%m%d_%H%M%S")
DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
@@ -31,22 +33,23 @@
"""
Example:
accelerate launch --num_processes=2 llava_next_video.py
accelerate launch llava_next_video.py
"""


def main(
    model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf",
    save_dir: str = "./evaluation/examples",
    dtype: str = "fp16",
    num_workers: int = 1,
    low_mem: bool = True,
):
    # Start up the distributed environment without needing the Accelerator.
    distributed_state = PartialState()

    processor = LlavaNextVideoProcessor.from_pretrained(model_name)
    model = LlavaNextVideoForConditionalGeneration.from_pretrained(
        model_name, torch_dtype=dtype[DTYPE_MAP], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
        model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device
    )

    if distributed_state.is_main_process:
@@ -109,33 +112,43 @@ def main(
        for i in range(0, len(conversations))
    ]

    count = 0
    distributed_state.num_processes = len(formatted_prompts)
    with distributed_state.split_between_processes(formatted_prompts) as prompts:
        input = processor(text=prompts, videos=video, return_tensors="pt").to(model.device)
        output = model.generate(**input, max_new_tokens=60)
        generated_text = processor.decode(output[0][2:], skip_special_tokens=True)

    distributed_state.wait_for_everyone()

    answers = gather_object(generated_text)
    input_prompts = gather_object(prompts)

    if distributed_state.is_main_process:
        for ans, prompt in zip(answers, input_prompts):
            count += 1
            example_file = f"example_{count}"
            temp_dir = os.path.join(save_dir, example_file)

            metadata = {
                "prompt": prompt,
                "generated_answer": ans,
            }
            with open(temp_dir, "w") as f:
                json.dump(metadata, f, indent=4)
    def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
        count = 0
        while True:
            try:
                item = output_queue.get(timeout=5)
                if item is None:
                    break
                example_file = f"example_{count}"
                temp_dir = os.path.join(output_dir, example_file)

                metadata = {
                    "prompt": item[0],
                    "generated_answer": item[1],
                }
                with open(temp_dir, "w") as f:
                    json.dump(metadata, f, indent=4)
                count += 1

            except queue.Empty:
                continue

    if distributed_state.is_main_process:
        print(f">>> Video answer generation Finished. Saved in {save_dir}")
    distributed_state.num_processes = len(formatted_prompts)
    output_queue = queue.Queue()
    save_thread = ThreadPoolExecutor(max_workers=num_workers)
    save_future = save_thread.submit(save_results, output_queue, save_dir)

    try:
        with distributed_state.split_between_processes(formatted_prompts) as prompt:
            input = processor(text=prompt, videos=video, padding=True, return_tensors="pt").to(model.device)
            output = model.generate(**input, max_new_tokens=60)
            generated_text = processor.decode(output[0][2:], skip_special_tokens=True)
            output_queue.put((prompt, generated_text))
    finally:
        output_queue.put(None)
        save_thread.shutdown(wait=True)

    save_future.result()


def read_video_pyav(container, indices):

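For reference, here is a minimal, runnable sketch of the queue-plus-writer-thread saving pattern this diff switches to: a single background worker drains a `queue.Queue`, writes each `(prompt, answer)` pair to its own JSON file, and stops when it receives a `None` sentinel. The prompts, dummy answers, output directory, and `.json` file names below are illustrative stand-ins, not taken from the example script.

```python
import json
import pathlib
import queue
from concurrent.futures import ThreadPoolExecutor


def save_results(output_queue: queue.Queue, output_dir: pathlib.Path):
    # Drain (prompt, answer) pairs from the queue and write one JSON file per pair.
    count = 0
    while True:
        try:
            item = output_queue.get(timeout=5)
            if item is None:  # sentinel: the producer is done
                break
            metadata = {"prompt": item[0], "generated_answer": item[1]}
            with open(output_dir / f"example_{count}.json", "w") as f:
                json.dump(metadata, f, indent=4)
            count += 1
        except queue.Empty:
            continue  # nothing arrived within the timeout; keep waiting for the sentinel


if __name__ == "__main__":
    save_dir = pathlib.Path("./evaluation/examples")  # illustrative output directory
    save_dir.mkdir(parents=True, exist_ok=True)

    output_queue: queue.Queue = queue.Queue()
    with ThreadPoolExecutor(max_workers=1) as save_thread:
        save_future = save_thread.submit(save_results, output_queue, save_dir)
        try:
            # Stand-in for the per-process generation loop: push each result as it is produced.
            for prompt in ["<video> Describe the clip.", "<video> Is anyone speaking?"]:
                output_queue.put((prompt, f"dummy answer for: {prompt}"))
        finally:
            output_queue.put(None)  # unblock the writer thread even if generation fails
        save_future.result()  # re-raise any exception from the writer thread
```

Writing from a single thread on the main process while generation runs keeps file I/O off the critical path and avoids interleaved writes from multiple ranks.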