Skip to content

Commit

Permalink
Refactor of Models API, Support for calling downloads (#985)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaby authored Dec 18, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent b27b806 commit a43c13f
Showing 2 changed files with 189 additions and 60 deletions.
187 changes: 130 additions & 57 deletions api/src/serge/routers/model.py
Original file line number Diff line number Diff line change
@@ -2,9 +2,10 @@
import os
import shutil

import aiohttp

from fastapi import APIRouter, HTTPException
import huggingface_hub
from huggingface_hub import hf_hub_url
from serge.models.models import Families

from pathlib import Path
@@ -14,6 +15,8 @@
tags=["model"],
)

active_downloads = {}

WEIGHTS = "/usr/src/app/weights/"

models_file_path = Path(__file__).parent.parent / "data" / "models.json"
@@ -34,13 +37,59 @@
# Helper functions
async def is_model_installed(model_name: str) -> bool:
    """Return whether ``<model_name>.bin`` is among the installed model files.

    The candidate names come from ``list_of_installed_models()``. An entry
    counts only when it equals ``f"{model_name}.bin"`` exactly and does not
    start with a dot, so dot-prefixed files are never reported as installed.

    Args:
        model_name: Model name without the ``.bin`` extension.

    Returns:
        ``True`` if a matching, non-hidden entry is listed, else ``False``.
    """
    target = f"{model_name}.bin"
    installed_models = await list_of_installed_models()
    # A single return: the earlier plain membership test made this check unreachable.
    return any(file_name == target and not file_name.startswith(".") for file_name in installed_models)


async def get_file_size(file_path: str) -> int:
    """Return the size, in bytes, of the file at ``file_path``.

    Args:
        file_path: Path of the file to measure.

    Returns:
        The file size in bytes as reported by the operating system.

    Raises:
        OSError: If the path cannot be stat'ed (for example, it does not exist).
    """
    size_in_bytes = os.path.getsize(file_path)
    return size_in_bytes


async def cleanup_model_resources(model_name: str):
    """Best-effort removal of the on-disk leftovers for ``model_name``.

    Looks the model up in ``models_info`` to obtain its repo id, then tries to
    delete, in this order and only if each path exists:

    * the dot-prefixed temporary file ``WEIGHTS/.<model_name>.bin``
    * the lock directory ``WEIGHTS/.locks/models--<repo>``
    * the cache directory ``WEIGHTS/models--<repo>``

    ``OSError`` from any individual removal is printed and does not stop the
    remaining removals. If the model has no repo entry, a notice is printed
    and nothing is removed.

    Args:
        model_name: Model name without the ``.bin`` extension.
    """
    repo_id, _, _ = models_info.get(model_name, (None, None, None))
    if not repo_id:
        print(f"No model repo found for {model_name}, cleanup may be incomplete.")
        return

    repo_dir_name = f"models--{repo_id.replace('/', '--')}"

    # (description used in the error message, path, function that deletes it)
    targets = (
        ("temporary file", os.path.join(WEIGHTS, f".{model_name}.bin"), os.remove),
        ("lock directory", os.path.join(WEIGHTS, ".locks", repo_dir_name), shutil.rmtree),
        ("cache directory", os.path.join(WEIGHTS, repo_dir_name), shutil.rmtree),
    )

    for label, target, remove in targets:
        if not os.path.exists(target):
            continue
        try:
            remove(target)
        except OSError as err:
            print(f"Error removing {label} for {model_name}: {err}")


async def download_file(session: aiohttp.ClientSession, url: str, path: str) -> None:
    """Stream the response body of a GET on ``url`` into the file at ``path``.

    Args:
        session: Client session used to issue the request.
        url: Address to download from.
        path: Destination file; opened in ``"wb"`` mode, so any existing
            content is overwritten.

    Raises:
        HTTPException: With status 500 when the response status is not 200.
            The destination file is not opened in that case.
    """
    async with session.get(url) as response:
        if response.status != 200:
            raise HTTPException(status_code=500, detail="Error downloading model")

        body = response.content
        with open(path, "wb") as out_file:
            # Copy 1024 bytes at a time; stop at the first empty (falsy) read.
            chunk = await body.read(1024)
            while chunk:
                out_file.write(chunk)
                chunk = await body.read(1024)


# Handlers
@model_router.get("/all")
async def list_of_all_models():
@@ -73,26 +122,15 @@ async def list_of_all_models():
return resp


@model_router.get("/downloadable")
async def list_of_downloadable_models():
    """Return the known model names that have no ``.bin`` file in ``WEIGHTS``.

    Only the top level of ``WEIGHTS`` is inspected (``os.listdir``). A model
    counts as installed when a file named ``<model_name>.bin`` is present.

    Returns:
        The keys of ``models_info``, in their original order, whose name is
        not among the installed ones.
    """
    suffix = ".bin"
    # Slice the suffix off rather than calling str.rstrip(".bin"): rstrip removes
    # any trailing run of the characters '.', 'b', 'i', 'n', so "gemini.bin"
    # would become "gem" and the model would wrongly look not installed.
    installed_models = {
        file_name[: -len(suffix)] for file_name in os.listdir(WEIGHTS) if file_name.endswith(suffix)
    }
    return [name for name in models_info if name not in installed_models]


@model_router.get("/installed")
async def list_of_installed_models():
    """List the ``.bin`` files found anywhere under ``WEIGHTS``.

    Walks ``WEIGHTS`` recursively and keeps every file whose name ends with
    ``.bin`` and does not start with a dot.

    Returns:
        A list of paths relative to ``WEIGHTS`` (no leading slash); files in
        the top-level directory are returned as bare file names.
    """
    # One coherent comprehension: the path is made relative while it is built,
    # so no separate lstrip("/") pass over the result is needed.
    files = [
        os.path.join(model_location.replace(WEIGHTS, "").lstrip("/"), bin_file)
        for model_location, _, filenames in os.walk(WEIGHTS)
        for bin_file in filenames
        if bin_file.endswith(".bin") and not bin_file.startswith(".")
    ]
    return files


@@ -102,18 +140,63 @@ async def download_model(model_name: str):
raise HTTPException(status_code=404, detail="Model not found")

try:
# Download file, and resume broken downloads
model_repo, filename, _ = models_info[model_name]
model_path = f"{WEIGHTS}{model_name}.bin"
await asyncio.to_thread(
huggingface_hub.hf_hub_download, repo_id=model_repo, filename=filename, local_dir=WEIGHTS, cache_dir=WEIGHTS, resume_download=True
)
# Rename file
os.rename(os.path.join(WEIGHTS, filename), os.path.join(WEIGHTS, model_path))
model_url = hf_hub_url(repo_id=model_repo, filename=filename)
temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
model_path = os.path.join(WEIGHTS, f"{model_name}.bin")

# Create an aiohttp session with timeout settings
timeout = aiohttp.ClientTimeout(total=300)
async with aiohttp.ClientSession(timeout=timeout) as session:
# Start the download and add to active_downloads
download_task = asyncio.create_task(download_file(session, model_url, temp_model_path))
active_downloads[model_name] = download_task
await download_task

# Rename the dotfile to its final name
os.rename(temp_model_path, model_path)

# Remove the entry from active_downloads after successful download
active_downloads.pop(model_name, None)

return {"message": f"Model {model_name} downloaded"}
except asyncio.CancelledError:
await cleanup_model_resources(model_name)
raise HTTPException(status_code=200, detail="Download cancelled")
except Exception as exc:
await cleanup_model_resources(model_name)
raise HTTPException(status_code=500, detail=f"Error downloading model: {exc}")


@model_router.post("/{model_name}/download/cancel")
async def cancel_download(model_name: str):
    """Cancel the active download task registered for ``model_name``.

    The task is looked up in ``active_downloads``, cancelled, removed from the
    registry and awaited; afterwards ``cleanup_model_resources`` is called for
    the model.

    Args:
        model_name: Model name used as the key in ``active_downloads``.

    Returns:
        A dict with a ``"message"`` confirming the cancellation.

    Raises:
        HTTPException: 404 when no task is registered for the model; 500 when
            cancelling or cleaning up raises an unexpected exception.
    """
    task = active_downloads.get(model_name)
    if not task:
        # Raised outside the try block so the generic handler below cannot
        # turn this 404 into a 500.
        raise HTTPException(status_code=404, detail="No active download for this model")

    try:
        task.cancel()

        # Remove entry from active downloads
        active_downloads.pop(model_name, None)

        try:
            # Wait for the task to respond to cancellation
            print(f"Waiting for download for {model_name} to be cancelled")
            await task
        except asyncio.CancelledError:
            # Handle the expected cancellation exception
            pass

        # Cleanup resources
        await cleanup_model_resources(model_name)

        print(f"Download for {model_name} cancelled")
        return {"message": f"Download for {model_name} cancelled"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error cancelling model download: {e}") from e


@model_router.get("/{model_name}/download/status")
@@ -125,48 +208,38 @@ async def download_status(model_name: str):
model_repo, _, _ = models_info[model_name]

# Construct the path to the blobs directory
blobs_dir = os.path.join(WEIGHTS, f"models--{model_repo.replace('/', '--')}", "blobs")

# Check for the .incomplete file in the blobs directory
if os.path.exists(os.path.join(WEIGHTS, f"{model_name}.bin")):
currentsize = os.path.getsize(os.path.join(WEIGHTS, f"{model_name}.bin"))
return min(round(currentsize / filesize * 100, 1), 100)
elif os.path.exists(blobs_dir):
for file in os.listdir(blobs_dir):
if file.endswith(".incomplete"):
incomplete_file_path = os.path.join(blobs_dir, file)
# Check if the .incomplete file exists and calculate the download status
if os.path.exists(incomplete_file_path):
currentsize = os.path.getsize(incomplete_file_path)
return min(round(currentsize / filesize * 100, 1), 100)
return 0
temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
model_path = os.path.join(WEIGHTS, f"{model_name}.bin")

# Check if the model is currently being downloaded
task = active_downloads.get(model_name)

if os.path.exists(model_path):
currentsize = os.path.getsize(model_path)
progress = min(round(currentsize / filesize * 100, 1), 100)
return progress
elif task and not task.done():
# If the task is still running, check for incomplete files
if os.path.exists(temp_model_path):
currentsize = os.path.getsize(temp_model_path)
return min(round(currentsize / filesize * 100, 1), 100)
# If temp_model_path doesn't exist, the download is likely just starting, progress is 0
return 0
else:
# No active download and the file does not exist
return None


@model_router.delete("/{model_name}")
async def delete_model(model_name: str):
    """Delete an installed model file and its related lock/cache directories.

    Args:
        model_name: Model name without the ``.bin`` extension.

    Returns:
        A dict with a ``"message"`` confirming the deletion.

    Raises:
        HTTPException: 404 when ``<model_name>.bin`` is not listed by
            ``list_of_installed_models()`` or when ``models_info`` has no repo
            entry for the model.
    """
    if f"{model_name}.bin" not in await list_of_installed_models():
        raise HTTPException(status_code=404, detail="Model not found")

    model_repo, _, _ = models_info.get(model_name, (None, None, None))
    if not model_repo:
        raise HTTPException(status_code=404, detail="Model info not found")

    # Remove link to model file
    try:
        os.remove(os.path.join(WEIGHTS, f"{model_name}.bin"))
    except OSError as e:
        print(f"Error removing model file: {e}")

    # The lock and cache directories are removed by the shared helper, which
    # checks that each path exists first. Removing them here as well was
    # redundant and printed errors whenever a directory was already absent.
    await cleanup_model_resources(model_name)

    return {"message": f"Model {model_name} deleted"}
62 changes: 59 additions & 3 deletions web/src/routes/models/+page.svelte
Original file line number Diff line number Diff line change
@@ -122,7 +122,15 @@
* @param model - The model name.
* @param isAvailable - Boolean indicating if the model is available.
*/
async function handleModelAction(model: string, isAvailable: boolean) {
async function handleModelAction(
model: string,
isAvailable: boolean,
isDownloading: boolean = false,
) {
if (isDownloading) {
await cancelDownload(model);
return;
}
const url = `/api/model/${model}${isAvailable ? "" : "/download"}`;
const method = isAvailable ? "DELETE" : "POST";
@@ -218,6 +226,42 @@
$: downloadedOrDownloadingModels = data.models
.filter((model) => model.progress > 0 || model.available)
.sort((a, b) => a.name.localeCompare(b.name));
/**
 * Asks the API to cancel the download of a model and, on success, resets the
 * model's UI state and removes it from the tracked downloads.
 * Failures are logged to the console; nothing is thrown to the caller.
 * @param modelName - The model name.
 */
async function cancelDownload(modelName: string) {
  try {
    const response = await fetch(`/api/model/${modelName}/download/cancel`, {
      method: "POST",
    });

    if (!response.ok) {
      console.error(`Failed to cancel download for ${modelName}`);
      return;
    }

    console.log(`Download for ${modelName} cancelled successfully.`);

    // Reset the model entry so the UI shows it as not downloaded
    const idx = data.models.findIndex((m) => m.name === modelName);
    if (idx !== -1) {
      data.models[idx].progress = 0;
      data.models[idx].available = false;
      data.models = [...data.models]; // trigger reactivity
    }

    // Stop tracking the model, both in memory and in local storage
    downloadingModels.delete(modelName);
    const stored = JSON.parse(localStorage.getItem("downloadingModels") || "[]");
    const remaining = stored.filter((model: string) => model !== modelName);
    localStorage.setItem("downloadingModels", JSON.stringify(remaining));
  } catch (error) {
    console.error(`Error cancelling download for ${modelName}:`, error);
  }
}
</script>

<div class="top-section">
@@ -249,13 +293,25 @@
</div>
{/if}
{#if model.progress >= 100}
<p>Size: {model.size / 1e9}GB</p>
<p>Size: {(model.size / 1e9).toFixed(2)} GB</p>
<button
on:click={() => handleModelAction(model.name, model.available)}
class="btn btn-error mt-2"
>
<Icon icon="mdi:trash" width="32" height="32" />
</button>
{:else}
<button
on:click={() =>
handleModelAction(
model.name,
model.available,
model.progress > 0 && model.progress < 100,
)}
class="btn btn-error mt-2"
>
<Icon icon="mdi:cancel" width="32" height="32" />
</button>
{/if}
</div>
</div>
@@ -287,7 +343,7 @@
{#if models.length === 1}
<h3>{truncateString(model.name, 24)}</h3>
{/if}
<p>Size: {model.size / 1e9}GB</p>
<p>Size: {(model.size / 1e9).toFixed(2)} GB</p>
<button
on:click={() => handleModelAction(model.name, model.available)}
class="btn btn-primary mt-2"

0 comments on commit a43c13f

Please sign in to comment.