diff --git a/comps/lvms/video-llama/server/requirements.txt b/comps/lvms/video-llama/server/requirements.txt
index 41dacfbd21..afbac6004b 100644
--- a/comps/lvms/video-llama/server/requirements.txt
+++ b/comps/lvms/video-llama/server/requirements.txt
@@ -31,4 +31,6 @@ torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cpu
 torchvision==0.14.1 --index-url https://download.pytorch.org/whl/cpu
 transformers
 uvicorn
+validators
 webdataset
+werkzeug
diff --git a/comps/lvms/video-llama/server/server.py b/comps/lvms/video-llama/server/server.py
index f54cdc65e4..20841732c5 100644
--- a/comps/lvms/video-llama/server/server.py
+++ b/comps/lvms/video-llama/server/server.py
@@ -5,12 +5,14 @@
 import argparse
 import logging
 import os
+import re
 from threading import Thread
 from urllib.parse import urlparse
 
 import decord
 import requests
 import uvicorn
+import validators
 from extract_vl_embedding import VLEmbeddingExtractor as VL
 from fastapi import FastAPI, Query
 from fastapi.middleware.cors import CORSMiddleware
@@ -21,6 +23,7 @@
 from transformers import TextIteratorStreamer, set_seed
 from video_llama.common.registry import registry
 from video_llama.conversation.conversation_video import Chat
+from werkzeug.utils import secure_filename
 
 # Initialize decord bridge and seed
 decord.bridge.set_bridge("torch")
@@ -33,7 +36,7 @@
 context_db = None
 streamer = None
 chat = None
-VIDEO_DIR = "/home/user/videos"
+VIDEO_DIR = "/home/user/comps/lvms/video-llama/server/data"
 CFG_PATH = "video_llama_config/video_llama_eval_only_vl.yaml"
 MODEL_TYPE = "llama_v2"
 
@@ -161,6 +164,54 @@ def is_local_file(url):
     return not url.startswith("http://") and not url.startswith("https://")
 
 
+def is_valid_url(url):
+    """Return True when ``url`` is an http(s) URL that is acceptable for download.
+
+    The URL must pass ``validators.url``, use the http or https scheme, contain
+    no "../" in its path, and have exactly one "." in its path (the extension).
+    Each rejection is logged and reported by returning False.
+    """
+    # Validate the URL's structure
+    validation = validators.url(url)
+    if not validation:
+        logging.error("URL is invalid")
+        return False
+
+    # Parse the URL to components
+    parsed_url = urlparse(url)
+
+    # Check the scheme
+    if parsed_url.scheme not in ["http", "https"]:
+        logging.error("URL scheme is invalid")
+        return False
+
+    # Check for "../" in the path
+    if "../" in parsed_url.path:
+        logging.error("URL contains '../', which is not allowed")
+        return False
+
+    # Check that the path only contains one "." for the file extension
+    if parsed_url.path.count(".") != 1:
+        logging.error("URL path does not meet the requirement of having only one '.'")
+        return False
+
+    # If all checks pass, the URL is valid
+    logging.info("URL is valid")
+    return True
+
+
+def is_valid_video(filename):
+    """Return the sanitized name when ``filename`` is a bare ``*.mp4`` file name, else False.
+
+    Only letters, digits, "-" and "_" may precede the extension, so anything
+    containing a path separator is rejected.
+    """
+    # fullmatch, unlike match with "$", does not accept a trailing newline
+    if re.fullmatch(r"[a-zA-Z0-9_-]+\.mp4", filename, re.IGNORECASE):
+        return secure_filename(filename)
+    return False
+
+
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -175,46 +226,60 @@ async def generate(
     prompt: str = Query(..., description="Query for Video-LLama", examples="What is the man doing?"),
     max_new_tokens: int = Query(150, description="Maximum number of tokens to generate", examples=150),
 ) -> StreamingResponse:
-    if not is_local_file(video_url):
-        parsed_url = urlparse(video_url)
-        video_name = os.path.basename(parsed_url.path)
-    else:
-        video_name = os.path.basename(video_url)
 
-    if video_name.lower().endswith(".mp4"):
-        logging.info(f"Format check passed, the file '{video_name}' is an MP4 file.")
+    if video_url.lower().endswith(".mp4"):
+        logging.info(f"Format check passed, the file '{video_url}' is an MP4 file.")
     else:
-        logging.info(f"Format check failed, the file '{video_name}' is not an MP4 file.")
+        logging.info(f"Format check failed, the file '{video_url}' is not an MP4 file.")
         return JSONResponse(status_code=400, content={"message": "Invalid file type. Only mp4 videos are allowed."})
 
-    if not is_local_file(video_url):
-        try:
-            video_path = os.path.join(VIDEO_DIR, video_name)
-            response = requests.get(video_url, stream=True)
-
-            if response.status_code == 200:
-                with open(video_path, "wb") as file:
-                    for chunk in response.iter_content(chunk_size=1024):
-                        if chunk:  # filter out keep-alive new chunks
-                            file.write(chunk)
-                logging.info(f"File downloaded: {video_path}")
-            else:
+    if is_local_file(video_url):
+        # validate the video name (only a bare file name, without path, is supported)
+        secure_video_name = is_valid_video(video_url)
+        if not secure_video_name:
+            return JSONResponse(status_code=400, content={"message": "Invalid file name."})
+
+        video_path = os.path.join(VIDEO_DIR, secure_video_name)
+        if os.path.exists(video_path):
+            logging.info(f"File found: {video_path}")
+        else:
+            logging.error(f"File not found: {video_path}")
+            return JSONResponse(
+                status_code=404, content={"message": "File not found. Only local files under data folder are allowed."}
+            )
+    else:
+        # validate the remote URL
+        if not is_valid_url(video_url):
+            return JSONResponse(status_code=400, content={"message": "Invalid URL."})
+
+        # The format check above looks at the whole URL, so a query string ending in
+        # ".mp4" would pass it; re-check the sanitized file name taken from the path.
+        video_name = secure_filename(os.path.basename(urlparse(video_url).path))
+        if not video_name.lower().endswith(".mp4"):
+            return JSONResponse(status_code=400, content={"message": "Invalid file type. Only mp4 videos are allowed."})
+
+        video_path = os.path.join(VIDEO_DIR, video_name)
+        try:
+            # A timeout keeps an unresponsive host from blocking this request forever.
+            response = requests.get(video_url, stream=True, timeout=30)
+            if response.status_code == 200:
+                with open(video_path, "wb") as file:
+                    for chunk in response.iter_content(chunk_size=1024):
+                        if chunk:  # filter out keep-alive new chunks
+                            file.write(chunk)
+                logging.info(f"File downloaded: {video_path}")
+            else:
                 logging.info(f"Error downloading file: {response.status_code}")
                 return JSONResponse(status_code=500, content={"message": "Error downloading file."})
         except Exception as e:
-            logging.info(f"Error downloading file: {response.status_code}")
+            # `response` is unbound if requests.get() itself raised, so log the exception.
+            logging.error(f"Error downloading file: {e}")
             return JSONResponse(status_code=500, content={"message": "Error downloading file."})
-    else:
-        # check if the video exist
-        video_path = video_url
-        if not os.path.exists(video_path):
-            logging.info(f"File not found: {video_path}")
-            return JSONResponse(status_code=404, content={"message": "File not found."})
+
     video_info = videoInfo(start_time=start, duration=duration, video_path=video_path)
 
     # format context and instruction
     instruction = f"{get_context(prompt,context_db)[0]}: {prompt}"
-    # logging.info("instruction:",instruction)
 
     return StreamingResponse(stream_res(video_info, instruction, max_new_tokens))
 
diff --git a/tests/test_lvms_video-llama.sh b/tests/test_lvms_video-llama.sh
index 1e94982fb3..a9dcbf3a7f 100755
--- a/tests/test_lvms_video-llama.sh
+++ b/tests/test_lvms_video-llama.sh
@@ -62,7 +62,7 @@ function start_service() {
 }
 
 function validate_microservice() {
-    result=$(http_proxy="" curl http://localhost:5031/v1/lvm -X POST -d '{"video_url":"./data/silence_girl.mp4","chunk_start": 0,"chunk_duration": 7,"prompt":"What is the person doing?","max_new_tokens": 50}' -H 'Content-Type: application/json')
+    result=$(http_proxy="" curl http://localhost:5031/v1/lvm -X POST -d '{"video_url":"silence_girl.mp4","chunk_start": 0,"chunk_duration": 7,"prompt":"What is the person doing?","max_new_tokens": 50}' -H 'Content-Type: application/json')
     if [[ $result == *"silence"* ]]; then
         echo "Result correct."
     else