diff --git a/recipes/quickstart/finetuning/LLM_finetuning_overview.md b/recipes/quickstart/finetuning/LLM_finetuning_overview.md index ca79bcb96..34e79ff35 100644 --- a/recipes/quickstart/finetuning/LLM_finetuning_overview.md +++ b/recipes/quickstart/finetuning/LLM_finetuning_overview.md @@ -61,4 +61,4 @@ To boost the performance of fine-tuning with FSDP, we can make use a number of f - **Activation Checkpointing** which is a technique to save memory by discarding the intermediate activation in forward pass instead of keeping it in the memory with the cost recomputing them in the backward pass. FSDP Activation checkpointing is shard aware meaning we need to apply it after wrapping the model with FSDP. In our script we are making use of that. -- **auto_wrap_policy** Which is the way to specify how FSDP would partition the model, there is default support for transformer wrapping policy. This allows FSDP to form each FSDP unit ( partition of the model ) based on the transformer class in the model. To identify this layer in the model, you need to look at the layer that wraps both the attention layer and MLP. This helps FSDP have more fine-grained units for communication that help with optimizing the communication cost. +- **auto_wrap_policy** Which is the way to specify how FSDP would partition the model, there is default support for transformer wrapping policy. This allows FSDP to form each FSDP unit ( partition of the model ) based on the transformer class in the model. To identify this layer in the model, you need to look at the layer that wraps both the attention layer and MLP. This helps FSDP have more fine-grained units for communication that help with optimizing the communication cost. \ No newline at end of file diff --git a/recipes/quickstart/inference/local_inference/README.md b/recipes/quickstart/inference/local_inference/README.md index 0bf2ad9d7..8e27304a2 100644 --- a/recipes/quickstart/inference/local_inference/README.md +++ b/recipes/quickstart/inference/local_inference/README.md @@ -3,26 +3,43 @@ ## Hugging face setup **Important Note**: Before running the inference, you'll need your Hugging Face access token, which you can get at your Settings page [here](https://huggingface.co/settings/tokens). Then run `huggingface-cli login` and copy and paste your Hugging Face access token to complete the login to make sure the scripts can download Hugging Face models if needed. -## Multimodal Inference -For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer.py) which uses the transformers library. +## Multimodal Inference and CLI inference with or without PEFT LoRA weights -The way to run this would be: -``` -python multi_modal_infer.py --image_path PATH_TO_IMAGE --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" -``` ---- -## Multi-modal Inferencing Using gradio UI for inferencing -For multi-modal inferencing using gradio UI we have added [multi_modal_infer_gradio_UI.py](multi_modal_infer_gradio_UI.py) which used gradio and transformers library. 
+### Model Overview
+- Base model: `meta-llama/Llama-3.2-11B-Vision-Instruct`
+- Uses the PEFT library (v0.13.1) for efficient fine-tuning
+- Supports vision-language tasks with instruction-following capabilities
-### Steps to Run
+### Features in `multi_modal_infer.py`
-The way to run this would be:
-- Ensure having proper access to llama 3.2 vision models, then run the command given below
+All functionality has been consolidated into a single file with three main modes. Make sure you are logged in with `huggingface-cli login` before running them.
+### Steps to run
+1. **Basic Inference**
+```bash
+python multi_modal_infer.py \
+    --image_path "path/to/image.jpg" \
+    --prompt_text "Describe this image" \
+    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
+```
+2. **Gradio UI Mode**
+```bash
+python multi_modal_infer.py \
+    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
+    --gradio_ui
 ```
-python multi_modal_infer_gradio_UI.py --hf_token
+
+3. **LoRA Fine-tuning Integration**
+```bash
+python multi_modal_infer.py \
+    --image_path "path/to/image.jpg" \
+    --prompt_text "Describe this image" \
+    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
+    --finetuning_path "path/to/lora/weights"
 ```
+
 ## Text-only Inference
 For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments.
diff --git a/recipes/quickstart/inference/local_inference/multi_modal_infer.py b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
index 27d45b5f1..071dc8683 100644
--- a/recipes/quickstart/inference/local_inference/multi_modal_infer.py
+++ b/recipes/quickstart/inference/local_inference/multi_modal_infer.py
@@ -1,108 +1,206 @@
 import argparse
 import os
 import sys
-
 import torch
 from accelerate import Accelerator
 from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-
+from peft import PeftModel
+import gradio as gr
+from huggingface_hub import HfFolder
+# Initialize accelerator
 accelerator = Accelerator()
-
 device = accelerator.device
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
+
+
+def get_hf_token():
+    """Retrieve Hugging Face token from the cache or environment."""
+    # Check if a token is explicitly set in the environment
+    token = os.getenv("HUGGINGFACE_TOKEN")
+    if token:
+        return token
+
+    # Automatically retrieve the token from the Hugging Face cache (set via huggingface-cli login)
+    token = HfFolder.get_token()
+    if token:
+        return token
+    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
+    sys.exit(1)
-def load_model_and_processor(model_name: str):
-    """
-    Load the model and processor based on the 11B or 90B model.
- """ + +def load_model_and_processor(model_name: str, finetuning_path: str = None): + """Load model and processor with optional LoRA adapter""" + print(f"Loading model: {model_name}") + hf_token = get_hf_token() model = MllamaForConditionalGeneration.from_pretrained( model_name, torch_dtype=torch.bfloat16, use_safetensors=True, device_map=device, + token=hf_token ) - processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True) - + processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True) + + if finetuning_path and os.path.exists(finetuning_path): + print(f"Loading LoRA adapter from '{finetuning_path}'...") + model = PeftModel.from_pretrained( + model, + finetuning_path, + is_adapter=True, + torch_dtype=torch.bfloat16 + ) + print("LoRA adapter merged successfully") + model, processor = accelerator.prepare(model, processor) return model, processor +def process_image(image_path: str = None, image = None) -> PIL_Image.Image: + """Process and validate image input""" + if image is not None: + return image.convert("RGB") + if image_path and os.path.exists(image_path): + return PIL_Image.open(image_path).convert("RGB") + raise ValueError("No valid image provided") -def process_image(image_path: str) -> PIL_Image.Image: - """ - Open and convert an image from the specified path. - """ - if not os.path.exists(image_path): - print(f"The image file '{image_path}' does not exist.") - sys.exit(1) - with open(image_path, "rb") as f: - return PIL_Image.open(f).convert("RGB") - - -def generate_text_from_image( - model, processor, image, prompt_text: str, temperature: float, top_p: float -): - """ - Generate text from an image using the model and processor. - """ +def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float): + """Generate text from image using model""" conversation = [ - { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "text": prompt_text}], - } + {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]} ] - prompt = processor.apply_chat_template( - conversation, add_generation_prompt=True, tokenize=False - ) + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) inputs = processor(image, prompt, return_tensors="pt").to(device) - output = model.generate( - **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512 - ) - return processor.decode(output[0])[len(prompt) :] - - -def main( - image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str -): - """ - Call all the functions. 
- """ - model, processor = load_model_and_processor(model_name) - image = process_image(image_path) - result = generate_text_from_image( - model, processor, image, prompt_text, temperature, top_p - ) - print("Generated Text: " + result) - + output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS) + return processor.decode(output[0])[len(prompt):] + +def gradio_interface(model_name: str): + """Create Gradio UI with LoRA support""" + # Initialize model state + current_model = {"model": None, "processor": None} + + def load_or_reload_model(enable_lora: bool, lora_path: str = None): + current_model["model"], current_model["processor"] = load_model_and_processor( + model_name, + lora_path if enable_lora else None + ) + return "Model loaded successfully" + (" with LoRA" if enable_lora else "") + + def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history): + if image is not None: + try: + processed_image = process_image(image=image) + result = generate_text_from_image( + current_model["model"], + current_model["processor"], + processed_image, + user_prompt, + temperature, + top_p + ) + history.append((user_prompt, result)) + except Exception as e: + history.append((user_prompt, f"Error: {str(e)}")) + return history + + def clear_chat(): + return [] + + with gr.Blocks() as demo: + gr.HTML("
<h1 style="text-align: center">Llama Vision Model Interface</h1>
") + + with gr.Row(): + with gr.Column(scale=1): + # Model loading controls + with gr.Group(): + enable_lora = gr.Checkbox(label="Enable LoRA", value=False) + lora_path = gr.Textbox( + label="LoRA Weights Path", + placeholder="Path to LoRA weights folder", + visible=False + ) + load_status = gr.Textbox(label="Load Status", interactive=False) + load_button = gr.Button("Load/Reload Model") + + # Image and parameter controls + image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512) + temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1) + top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1) + top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1) + max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50) + + with gr.Column(scale=2): + chat_history = gr.Chatbot(label="Chat", height=512) + user_prompt = gr.Textbox( + show_label=False, + placeholder="Enter your prompt", + lines=2 + ) + + with gr.Row(): + generate_button = gr.Button("Generate") + clear_button = gr.Button("Clear") + + # Event handlers + enable_lora.change( + fn=lambda x: gr.update(visible=x), + inputs=[enable_lora], + outputs=[lora_path] + ) + + load_button.click( + fn=load_or_reload_model, + inputs=[enable_lora, lora_path], + outputs=[load_status] + ) + + generate_button.click( + fn=describe_image, + inputs=[ + image_input, user_prompt, temperature, + top_k, top_p, max_tokens, chat_history + ], + outputs=[chat_history] + ) + + clear_button.click(fn=clear_chat, outputs=[chat_history]) + + # Initial model load + load_or_reload_model(False) + return demo + +def main(args): + """Main execution flow""" + if args.gradio_ui: + demo = gradio_interface(args.model_name) + demo.launch() + else: + model, processor = load_model_and_processor( + args.model_name, + args.finetuning_path + ) + image = process_image(image_path=args.image_path) + result = generate_text_from_image( + model, processor, image, + args.prompt_text, + args.temperature, + args.top_p + ) + print("Generated Text:", result) if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate text from an image and prompt using the 3.2 MM Llama model." 
- ) - parser.add_argument("--image_path", type=str, help="Path to the image file") - parser.add_argument( - "--prompt_text", type=str, help="Prompt text to describe the image" - ) - parser.add_argument( - "--temperature", - type=float, - default=0.7, - help="Temperature for generation (default: 0.7)", - ) - parser.add_argument( - "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)" - ) - parser.add_argument( - "--model_name", - type=str, - default=DEFAULT_MODEL, - help=f"Model name (default: '{DEFAULT_MODEL}')", - ) - + parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support") + parser.add_argument("--image_path", type=str, help="Path to the input image") + parser.add_argument("--prompt_text", type=str, help="Prompt text for the image") + parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature") + parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling") + parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name") + parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights") + parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI") + args = parser.parse_args() - main( - args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name - ) + main(args) diff --git a/recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py b/recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py deleted file mode 100644 index 5119ac7c3..000000000 --- a/recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py +++ /dev/null @@ -1,157 +0,0 @@ -import gradio as gr -import torch -import os -from PIL import Image -from accelerate import Accelerator -from transformers import MllamaForConditionalGeneration, AutoProcessor -import argparse # Import argparse - -# Parse the command line arguments -parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model") -parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token") -args = parser.parse_args() - -# Hugging Face token -hf_token = args.hf_token - -# Initialize Accelerator -accelerate = Accelerator() -device = accelerate.device - -# Set memory management for PyTorch -os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128' # or adjust size as needed - -# Model ID -model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" - -# Load model with the Hugging Face token -model = MllamaForConditionalGeneration.from_pretrained( - model_id, - torch_dtype=torch.bfloat16, - device_map=device, - use_auth_token=hf_token # Pass the Hugging Face token here -) - -# Load the processor -processor = AutoProcessor.from_pretrained(model_id, use_auth_token=hf_token) - -# Visual theme -visual_theme = gr.themes.Default() # Default, Soft or Monochrome - -# Constants -MAX_OUTPUT_TOKENS = 2048 -MAX_IMAGE_SIZE = (1120, 1120) - -# Function to process the image and generate a description -def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history): - # Initialize cleaned_output variable - cleaned_output = "" - - if image is not None: - # Resize image if necessary - image = image.resize(MAX_IMAGE_SIZE) - prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:" - # Preprocess the image and prompt - inputs = processor(image, prompt, return_tensors="pt").to(device) - else: - # Text-only input if no image is provided - prompt = 
f"<|begin_of_text|>{user_prompt} Answer:" - # Preprocess the prompt only (no image) - inputs = processor(prompt, return_tensors="pt").to(device) - - # Generate output with model - output = model.generate( - **inputs, - max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS), - temperature=temperature, - top_k=top_k, - top_p=top_p - ) - - # Decode the raw output - raw_output = processor.decode(output[0]) - - # Clean up the output to remove system tokens - cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "") - - # Ensure the prompt is not repeated in the output - if cleaned_output.startswith(user_prompt): - cleaned_output = cleaned_output[len(user_prompt):].strip() - - # Append the new conversation to the history - history.append((user_prompt, cleaned_output)) - - return history - - -# Function to clear the chat history -def clear_chat(): - return [] - -# Gradio Interface -def gradio_interface(): - with gr.Blocks(visual_theme) as demo: - gr.HTML( - """ -
-    <h1 style="text-align: center">
-    meta-llama/Llama-3.2-11B-Vision-Instruct
-    </h1>
- """) - with gr.Row(): - # Left column with image and parameter inputs - with gr.Column(scale=1): - image_input = gr.Image( - label="Image", - type="pil", - image_mode="RGB", - height=512, # Set the height - width=512 # Set the width - ) - - # Parameter sliders - temperature = gr.Slider( - label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True) - top_k = gr.Slider( - label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True) - top_p = gr.Slider( - label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True) - max_tokens = gr.Slider( - label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True) - - # Right column with the chat interface - with gr.Column(scale=2): - chat_history = gr.Chatbot(label="Chat", height=512) - - # User input box for prompt - user_prompt = gr.Textbox( - show_label=False, - container=False, - placeholder="Enter your prompt", - lines=2 - ) - - # Generate and Clear buttons - with gr.Row(): - generate_button = gr.Button("Generate") - clear_button = gr.Button("Clear") - - # Define the action for the generate button - generate_button.click( - fn=describe_image, - inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history], - outputs=[chat_history] - ) - - # Define the action for the clear button - clear_button.click( - fn=clear_chat, - inputs=[], - outputs=[chat_history] - ) - - return demo - -# Launch the interface -demo = gradio_interface() -# demo.launch(server_name="0.0.0.0", server_port=12003) -demo.launch() \ No newline at end of file
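Beyond the CLI entry point, the helpers that this diff consolidates into `multi_modal_infer.py` can also be driven programmatically. The sketch below is illustrative only: it assumes the script is importable from the working directory and that `path/to/image.jpg` is replaced with a real image; the function signatures are taken from the diff above.

```python
# Minimal sketch of programmatic use of the consolidated script.
# Assumes multi_modal_infer.py (as added in this diff) is on the Python path
# and that "path/to/image.jpg" is replaced with a real image file.
from multi_modal_infer import (
    load_model_and_processor,
    process_image,
    generate_text_from_image,
)

# Optionally pass finetuning_path to attach LoRA weights, mirroring the CLI flag.
model, processor = load_model_and_processor(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    finetuning_path=None,
)
image = process_image(image_path="path/to/image.jpg")
result = generate_text_from_image(
    model, processor, image,
    "Describe this image",
    temperature=0.7,
    top_p=0.9,
)
print(result)
```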
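Similarly, `--gradio_ui` launches the interface with Gradio's defaults. The deleted `multi_modal_infer_gradio_UI.py` carried a commented-out `demo.launch(server_name="0.0.0.0", server_port=12003)`; a rough equivalent with the consolidated script is sketched below, where the host and port values are only examples.

```python
# Sketch: exposing the Gradio UI on a specific host/port instead of relying on
# the defaults used by the --gradio_ui flag. gradio_interface() is defined in
# the consolidated multi_modal_infer.py; the host/port values are examples.
from multi_modal_infer import gradio_interface

demo = gradio_interface("meta-llama/Llama-3.2-11B-Vision-Instruct")
demo.launch(server_name="0.0.0.0", server_port=7860)
```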
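One note on the LoRA path: `PeftModel.from_pretrained` in `load_model_and_processor` attaches the adapter for inference; it does not fold the LoRA deltas into the base weights, despite the "merged successfully" log line. If a truly merged model is wanted, a separate step along the lines of the sketch below would be needed. This assumes PEFT's `merge_and_unload()` API and is not something the script in this diff does; `path/to/lora/weights` is a placeholder.

```python
# Sketch: explicitly merging LoRA weights into the base model.
# Assumes PEFT's merge_and_unload(); "path/to/lora/weights" is a placeholder.
import torch
from peft import PeftModel
from transformers import MllamaForConditionalGeneration

base = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    torch_dtype=torch.bfloat16,
)
adapted = PeftModel.from_pretrained(base, "path/to/lora/weights")
merged = adapted.merge_and_unload()  # folds the LoRA deltas into the base weights
```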