import asyncio
import os
import shutil
from pathlib import Path

import aiohttp
from fastapi import APIRouter, HTTPException
from huggingface_hub import hf_hub_url

from serge.models.models import Families

model_router = APIRouter(
    prefix="/model",
    tags=["model"],
)

# model_name -> asyncio.Task for downloads currently in flight.
active_downloads = {}

WEIGHTS = "/usr/src/app/weights/"

models_file_path = Path(__file__).parent.parent / "data" / "models.json"

# NOTE(review): the original module populates `models_info`
# (model_name -> (repo_id, filename, size)) and family metadata from
# models.json between here and the helpers; that code is unchanged by the
# patch and not visible in this view — confirm against the full file.


# Helper functions
async def is_model_installed(model_name: str) -> bool:
    """Return True when ``<model_name>.bin`` exists among the installed weights."""
    installed_models = await list_of_installed_models()
    # Entries are paths relative to WEIGHTS; in-progress downloads are hidden
    # dotfiles and already filtered out by list_of_installed_models, so the
    # startswith() guard here is belt-and-braces.
    return any(file_name == f"{model_name}.bin" and not file_name.startswith(".") for file_name in installed_models)


async def get_file_size(file_path: str) -> int:
    """Return the size of *file_path* in bytes."""
    return os.stat(file_path).st_size


async def cleanup_model_resources(model_name: str) -> None:
    """Best-effort removal of a model's partial-download artifacts.

    Removes the hidden temp file plus the huggingface-hub-style lock and
    cache directories. Never raises: failures are printed so callers can
    continue with their own error handling.
    """
    model_repo, _, _ = models_info.get(model_name, (None, None, None))
    if not model_repo:
        print(f"No model repo found for {model_name}, cleanup may be incomplete.")
        return

    temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
    lock_dir = os.path.join(WEIGHTS, ".locks", f"models--{model_repo.replace('/', '--')}")
    cache_dir = os.path.join(WEIGHTS, f"models--{model_repo.replace('/', '--')}")

    # Partially downloaded temp file.
    if os.path.exists(temp_model_path):
        try:
            os.remove(temp_model_path)
        except OSError as e:
            print(f"Error removing temporary file for {model_name}: {e}")

    # huggingface-hub-style lock directory.
    if os.path.exists(lock_dir):
        try:
            shutil.rmtree(lock_dir)
        except OSError as e:
            print(f"Error removing lock directory for {model_name}: {e}")

    # huggingface-hub-style cache directory.
    if os.path.exists(cache_dir):
        try:
            shutil.rmtree(cache_dir)
        except OSError as e:
            print(f"Error removing cache directory for {model_name}: {e}")


async def download_file(session: aiohttp.ClientSession, url: str, path: str) -> None:
    """Stream *url* into *path* in 1 KiB chunks.

    Raises HTTPException(500) on any non-200 response. File writes are
    synchronous (local disk), so each chunk briefly blocks the event loop.
    """
    async with session.get(url) as response:
        if response.status != 200:
            raise HTTPException(status_code=500, detail="Error downloading model")

        with open(path, "wb") as f:
            while True:
                chunk = await response.content.read(1024)
                if not chunk:
                    break
                f.write(chunk)


# Handlers
# NOTE(review): the unchanged `GET /all` handler (list_of_all_models) sits
# here in the full file; only its tail appears in the diff context, so it is
# not reproduced.


@model_router.get("/installed")
async def list_of_installed_models():
    # Walk WEIGHTS and return *.bin paths relative to WEIGHTS, skipping
    # hidden dotfiles (in-flight downloads are written as ".<name>.bin").
    return [
        os.path.join(model_location.replace(WEIGHTS, "").lstrip("/"), bin_file)
        for model_location, _, filenames in os.walk(WEIGHTS)
        for bin_file in filenames
        if bin_file.endswith(".bin") and not bin_file.startswith(".")
    ]


@model_router.post("/{model_name}/download")
async def download_model(model_name: str):
    """Download a model's weights to WEIGHTS/<model_name>.bin.

    The file is first written as a hidden ".<model_name>.bin" and renamed on
    success, so partial downloads never appear installed.
    """
    if model_name not in models_info:
        raise HTTPException(status_code=404, detail="Model not found")

    model_repo, filename, _ = models_info[model_name]
    model_url = hf_hub_url(repo_id=model_repo, filename=filename)
    temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
    model_path = os.path.join(WEIGHTS, f"{model_name}.bin")

    try:
        timeout = aiohttp.ClientTimeout(total=300)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            # Register the task so POST /{model_name}/download/cancel can
            # cancel it.
            download_task = asyncio.create_task(download_file(session, model_url, temp_model_path))
            active_downloads[model_name] = download_task
            await download_task

        # Promote the hidden temp file to its final name only on success.
        os.rename(temp_model_path, model_path)
        return {"message": f"Model {model_name} downloaded"}
    except asyncio.CancelledError:
        # Cancellation is a normal outcome of the cancel endpoint, not an
        # error: return a regular payload instead of *raising*
        # HTTPException(status_code=200) as the previous code did.
        await cleanup_model_resources(model_name)
        return {"message": f"Download of {model_name} cancelled"}
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 500 from download_file)
        # instead of re-wrapping them.
        await cleanup_model_resources(model_name)
        raise
    except Exception as exc:
        await cleanup_model_resources(model_name)
        raise HTTPException(status_code=500, detail=f"Error downloading model: {exc}")
    finally:
        # Drop the tracking entry on every exit path; previously it was only
        # removed on success, leaking stale entries for failed downloads.
        active_downloads.pop(model_name, None)


@model_router.post("/{model_name}/download/cancel")
async def cancel_download(model_name: str):
    """Cancel an in-flight download and clean up its partial artifacts."""
    # Looked up and popped *outside* the broad handler below: previously the
    # 404 was raised inside `try ... except Exception` and got re-wrapped as
    # a 500.
    task = active_downloads.pop(model_name, None)
    if not task:
        raise HTTPException(status_code=404, detail="No active download for this model")

    try:
        # Request cancellation, then wait for the task to acknowledge it.
        task.cancel()
        print(f"Waiting for download for {model_name} to be cancelled")
        try:
            await task
        except asyncio.CancelledError:
            pass  # Expected once the task observes the cancellation.

        # Remove the temp file and hf-style lock/cache leftovers.
        await cleanup_model_resources(model_name)

        print(f"Download for {model_name} cancelled")
        return {"message": f"Download for {model_name} cancelled"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error cancelling model download: {str(e)}")


@model_router.get("/{model_name}/download/status")
async def download_status(model_name: str):
    """Return download progress as a capped percentage, 0 when a download has
    just started, or None when there is neither a file nor an active download."""
    # NOTE(review): this function's opening (404 check and the `filesize`
    # lookup) lies outside the visible diff context and is reconstructed from
    # its usage below — confirm against the full file.
    if model_name not in models_info:
        raise HTTPException(status_code=404, detail="Model not found")
    _, _, filesize = models_info[model_name]

    temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
    model_path = os.path.join(WEIGHTS, f"{model_name}.bin")

    # Is the model currently being downloaded?
    task = active_downloads.get(model_name)

    if os.path.exists(model_path):
        # Finished (or pre-existing) file: report capped percentage.
        currentsize = os.path.getsize(model_path)
        return min(round(currentsize / filesize * 100, 1), 100)
    elif task and not task.done():
        if os.path.exists(temp_model_path):
            currentsize = os.path.getsize(temp_model_path)
            return min(round(currentsize / filesize * 100, 1), 100)
        # Download just started; nothing on disk yet.
        return 0
    else:
        # No file on disk and no active download.
        return None


@model_router.delete("/{model_name}")
async def delete_model(model_name: str):
    """Delete an installed model's weights and associated cache artifacts."""
    if f"{model_name}.bin" not in await list_of_installed_models():
        raise HTTPException(status_code=404, detail="Model not found")

    try:
        os.remove(os.path.join(WEIGHTS, f"{model_name}.bin"))
    except OSError as e:
        print(f"Error removing model file: {e}")

    # Shared cleanup replaces the previous inline lock/cache removal.
    await cleanup_model_resources(model_name)

    return {"message": f"Model {model_name} deleted"}
*/ - async function handleModelAction(model: string, isAvailable: boolean) { + async function handleModelAction( + model: string, + isAvailable: boolean, + isDownloading: boolean = false, + ) { + if (isDownloading) { + await cancelDownload(model); + return; + } const url = `/api/model/${model}${isAvailable ? "" : "/download"}`; const method = isAvailable ? "DELETE" : "POST"; @@ -218,6 +226,42 @@ $: downloadedOrDownloadingModels = data.models .filter((model) => model.progress > 0 || model.available) .sort((a, b) => a.name.localeCompare(b.name)); + + async function cancelDownload(modelName: string) { + try { + const response = await fetch(`/api/model/${modelName}/download/cancel`, { + method: "POST", + }); + + if (response.ok) { + console.log(`Download for ${modelName} cancelled successfully.`); + // Update UI based on successful cancellation + const modelIndex = data.models.findIndex((m) => m.name === modelName); + if (modelIndex !== -1) { + data.models[modelIndex].progress = 0; + data.models[modelIndex].available = false; + data.models = [...data.models]; // trigger reactivity + } + + // Remove model from tracking and local storage + downloadingModels.delete(modelName); + const currentDownloads = JSON.parse( + localStorage.getItem("downloadingModels") || "[]", + ); + const updatedDownloads = currentDownloads.filter( + (model: string) => model !== modelName, + ); + localStorage.setItem( + "downloadingModels", + JSON.stringify(updatedDownloads), + ); + } else { + console.error(`Failed to cancel download for ${modelName}`); + } + } catch (error) { + console.error(`Error cancelling download for ${modelName}:`, error); + } + }
@@ -249,13 +293,25 @@
{/if} {#if model.progress >= 100} -

Size: {model.size / 1e9}GB

+

Size: {(model.size / 1e9).toFixed(2)} GB

+ {:else} + {/if} @@ -287,7 +343,7 @@ {#if models.length === 1}

{truncateString(model.name, 24)}

{/if} -

Size: {model.size / 1e9}GB

+

Size: {(model.size / 1e9).toFixed(2)} GB