Segment anything 2 pipeline image #185

Merged · 33 commits · Sep 4, 2024

Commits
067abbb
feat(pipeline): add SAM2 image segmentation prototype
rickstaa Aug 13, 2024
1f8fe62
revert Dockerfile, requirements, add sam2 Dockerfile
pschroedl Aug 14, 2024
7ec3079
refactor: enhance SAM2 input handling and error management
rickstaa Aug 14, 2024
2828eb9
refactor: improve SAM2 return time
rickstaa Aug 14, 2024
8b8a508
Sam2 -> SegmentAnything2
pschroedl Aug 14, 2024
b96a2b7
update go bindings
pschroedl Aug 14, 2024
ecab548
update multipart.go binding with NewSegmentAnything2Writer
pschroedl Aug 15, 2024
b7b4cdc
update worker and multipart methods
pschroedl Aug 15, 2024
b200ab6
predictions -> scores, mask -> logits
pschroedl Aug 26, 2024
9b5b72e
add sam2 specific multipartwriter fields
pschroedl Aug 27, 2024
e503818
add segment-anything-2 to containerHostPorts
pschroedl Aug 27, 2024
36ebcce
fix pipeline name in worker.go
eliteprox Aug 28, 2024
51206d8
Merge branch 'segment_anything_2_pipeline_image' into add_segment_any…
rickstaa Aug 28, 2024
96a8157
revert Dockerfile, requirements, add sam2 Dockerfile
pschroedl Aug 14, 2024
389bd9f
Sam2 -> SegmentAnything2
pschroedl Aug 14, 2024
468552c
predictions -> scores, mask -> logits
pschroedl Aug 26, 2024
5e13f60
Merge branch 'main' into segment_anything_2_pipeline_image
rickstaa Aug 28, 2024
ce8acc8
feat: replace JSON.dump with str
rickstaa Aug 28, 2024
57bd5cf
move pipeline-specific dockerfile
Aug 30, 2024
02267cf
update openapi yaml
Aug 30, 2024
842466f
Merge branch 'main' into segment_anything_2_pipeline_image
pschroedl Aug 30, 2024
fdd1c31
add segment anything specific readme
Aug 30, 2024
74695ab
update go bindings
Aug 30, 2024
7178ddd
refactor: move SAM2 docker
rickstaa Sep 3, 2024
118803e
refactor: add FastAPI descriptions
rickstaa Sep 3, 2024
c520b49
refactor: improve sam2 route function name
rickstaa Sep 3, 2024
a5b4e73
chore(worker): update golang bindings
rickstaa Sep 3, 2024
323a376
refactor(runner): add media_type
rickstaa Sep 3, 2024
acd6407
chore(worker): remove debug patch
rickstaa Sep 3, 2024
50fcc14
feat(runner): add SAM2 model download command
rickstaa Sep 3, 2024
fc8a33d
refactor(worker): change SAM2 multipart reader param order
rickstaa Sep 3, 2024
6cd1a4a
determine docker image in createContainer
pschroedl Sep 3, 2024
225dfd5
fix: fix examples
rickstaa Sep 3, 2024
Files changed
8 changes: 8 additions & 0 deletions runner/app/main.py
@@ -52,6 +52,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case "segment-anything-2":
from app.pipelines.segment_anything_2 import SegmentAnything2Pipeline

return SegmentAnything2Pipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -82,6 +86,10 @@ def load_route(pipeline: str) -> any:
from app.routes import upscale

return upscale.router
case "segment-anything-2":
from app.routes import segment_anything_2

return segment_anything_2.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

41 changes: 41 additions & 0 deletions runner/app/pipelines/segment_anything_2.py
@@ -0,0 +1,41 @@
import logging
from typing import List, Optional, Tuple

import PIL
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir
from app.routes.util import InferenceError
from PIL import ImageFile
from sam2.sam2_image_predictor import SAM2ImagePredictor

ImageFile.LOAD_TRUNCATED_IMAGES = True

logger = logging.getLogger(__name__)


class SegmentAnything2Pipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()

self.tm = SAM2ImagePredictor.from_pretrained(
model_id=model_id,
device=torch_device,
**kwargs,
)

def __call__(
self, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
try:
self.tm.set_image(image)
prediction = self.tm.predict(**kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

return prediction

def __str__(self) -> str:
return f"Segment Anything 2 model_id={self.model_id}"
179 changes: 179 additions & 0 deletions runner/app/routes/segment_anything_2.py
@@ -0,0 +1,179 @@
import logging
import os
from typing import Annotated

import numpy as np
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import (
HTTPError,
InferenceError,
MasksResponse,
http_error,
json_str_to_np_array,
)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373.
@router.post(
"/segment-anything-2",
response_model=MasksResponse,
responses=RESPONSES,
description="Segment objects in an image.",
)
@router.post(
"/segment-anything-2/",
response_model=MasksResponse,
responses=RESPONSES,
include_in_schema=False,
)
async def segment_anything_2(
image: Annotated[
UploadFile, File(description="Image to segment.", media_type="image/*")
],
model_id: Annotated[
str, Form(description="Hugging Face model ID used for image segmentation.")
] = "",
point_coords: Annotated[
str,
Form(
description=(
"Nx2 array of point prompts to the model, where each point is in (X,Y) "
"in pixels."
)
),
] = None,
point_labels: Annotated[
str,
Form(
description=(
"Labels for the point prompts, where 1 indicates a foreground point "
"and 0 indicates a background point."
)
),
] = None,
box: Annotated[
str,
Form(
description=(
"A length 4 array given as a box prompt to the model, in XYXY format."
)
),
] = None,
mask_input: Annotated[
str,
Form(
description=(
"A low-resolution mask input to the model, typically from a previous "
"prediction iteration, with the form 1xHxW (H=W=256 for SAM)."
)
),
] = None,
multimask_output: Annotated[
bool,
Form(
description=(
"If true, the model will return three masks for ambiguous input "
"prompts, often producing better masks than a single prediction."
)
),
] = True,
return_logits: Annotated[
bool,
Form(
description=(
"If true, returns un-thresholded mask logits instead of a binary mask."
)
),
] = True,
normalize_coords: Annotated[
bool,
Form(
description=(
"If true, the point coordinates will be normalized to the range [0,1], "
"with point_coords expected to be with respect to image dimensions."
)
),
] = True,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
),
)

try:
point_coords = json_str_to_np_array(point_coords, var_name="point_coords")
point_labels = json_str_to_np_array(point_labels, var_name="point_labels")
box = json_str_to_np_array(box, var_name="box")
mask_input = json_str_to_np_array(mask_input, var_name="mask_input")
except ValueError as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)

try:
image = Image.open(image.file).convert("RGB")
masks, scores, low_res_mask_logits = pipeline(
image,
point_coords=point_coords,
point_labels=point_labels,
box=box,
mask_input=mask_input,
multimask_output=multimask_output,
return_logits=return_logits,
normalize_coords=normalize_coords,
)
except Exception as e:
logger.error(f"Segment Anything 2 error: {e}")
logger.exception(e)
if isinstance(e, InferenceError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("Segment Anything 2 error"),
)

# Return masks sorted by descending score as string.
sorted_ind = np.argsort(scores)[::-1]
return {
"masks": str(masks[sorted_ind].tolist()),
"scores": str(scores[sorted_ind].tolist()),
"logits": str(low_res_mask_logits[sorted_ind].tolist()),
}
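
For readers wiring up a client, the sketch below shows a hedged call to this route. The host, port, bearer token, and file name are assumptions; the form field names, the JSON-string encoding of array prompts, and the stringified response fields mirror the handler above and the `MasksResponse` model in the next file.

```python
# Hypothetical client for the /segment-anything-2 route (not part of the PR).
import json

import requests

with open("example.png", "rb") as f:  # hypothetical input image
    resp = requests.post(
        "http://localhost:8000/segment-anything-2",  # assumed runner address
        headers={"Authorization": "Bearer my-token"},  # only if AUTH_TOKEN is set
        files={"image": f},
        data={
            "model_id": "facebook/sam2-hiera-large",
            # Array-valued prompts travel as JSON strings; the handler parses
            # them with json_str_to_np_array.
            "point_coords": json.dumps([[250, 250]]),
            "point_labels": json.dumps([1]),
        },
    )
resp.raise_for_status()
body = resp.json()

# Per MasksResponse, masks/scores/logits arrive as stringified lists; the
# numeric scores parse cleanly as JSON.
scores = json.loads(body["scores"])
print(scores)
```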
60 changes: 59 additions & 1 deletion runner/app/routes/util.py
@@ -1,8 +1,10 @@
import base64
import io
import json
import os
from typing import List
from typing import List, Optional

import numpy as np
from fastapi import UploadFile
from PIL import Image
from pydantic import BaseModel, Field
@@ -30,6 +32,18 @@ class VideoResponse(BaseModel):
frames: List[List[Media]] = Field(..., description="The generated video frames.")


class MasksResponse(BaseModel):
"""Response model for object segmentation."""

masks: str = Field(..., description="The generated masks.")
scores: str = Field(
..., description="The model's confidence scores for each generated mask."
)
logits: str = Field(
..., description="The raw, unnormalized predictions (logits) for the masks."
)


class chunk(BaseModel):
"""A chunk of text with a timestamp."""

@@ -56,6 +70,22 @@ class HTTPError(BaseModel):
detail: APIError = Field(..., description="Detailed error information.")


class InferenceError(Exception):
"""Exception raised for errors during model inference."""

def __init__(self, message="Error during model execution", original_exception=None):
"""Initialize the exception.

Args:
message: The error message.
original_exception: The original exception that caused the error.
"""
if original_exception:
message = f"{message}: {original_exception}"
super().__init__(message)
self.original_exception = original_exception


def http_error(msg: str) -> HTTPError:
"""Create an HTTP error response with the specified message.

@@ -118,3 +148,31 @@ def file_exceeds_max_size(
except Exception as e:
print(f"Error checking file size: {e}")
return False


def json_str_to_np_array(
data: Optional[str], var_name: Optional[str] = None
) -> Optional[np.ndarray]:
"""Converts a JSON string to a NumPy array.

Args:
data: The JSON string to convert.
var_name: The name of the variable being converted. Used in error messages.

Returns:
The NumPy array if the conversion is successful, None otherwise.

Raises:
ValueError: If an error occurs during JSON parsing.
"""
if data:
try:
array = np.array(json.loads(data))
return array
except json.JSONDecodeError as e:
error_message = "Error parsing JSON"
if var_name:
error_message += f" for {var_name}"
error_message += f": {e}"
raise ValueError(error_message)
return None
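
A quick, hypothetical illustration of `json_str_to_np_array`'s contract — the success, pass-through, and failure paths all follow the implementation above.

```python
# Hypothetical usage sketch of json_str_to_np_array (not part of the PR).
from app.routes.util import json_str_to_np_array

coords = json_str_to_np_array("[[100, 200], [300, 400]]", var_name="point_coords")
print(coords.shape)  # (2, 2)

print(json_str_to_np_array(None))  # None — empty or None input passes through

try:
    json_str_to_np_array("not-json", var_name="box")
except ValueError as e:
    print(e)  # Error parsing JSON for box: ...
```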
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
@@ -60,6 +60,9 @@ function download_all_models() {

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Custom pipeline models.
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models
}

# Enable HF transfer acceleration.
5 changes: 5 additions & 0 deletions runner/docker/Dockerfile.segment_anything_2
@@ -0,0 +1,5 @@
FROM livepeer/ai-runner:base

RUN pip install --no-cache-dir torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 xformers==0.0.27 git+https://github.com/facebookresearch/segment-anything-2.git@main#egg=sam-2

CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"]
36 changes: 36 additions & 0 deletions runner/docker/README.md
@@ -0,0 +1,36 @@
# Runner Docker Images

This folder contains Dockerfiles for pipelines supported by the Livepeer AI network. The list is maintained by the Livepeer community and audited by the [Core AI team](https://explorer.livepeer.org/treasury/42084921863832634370966409987770520882792921083596034115019946998721416745190). In the future, we will enable custom pipelines to be used with the Livepeer AI network.

## Building a Pipeline-Specific Container

> [!NOTE]
> We are transitioning our existing pipelines to this new structure. As a result, the base container is currently somewhat bloated. In the future, the base image will contain only the necessary dependencies to run any pipeline.

All pipeline-specific containers are built on top of the base container found in the main [runner](../) folder and on [Docker Hub](https://hub.docker.com/r/livepeer/ai-runner). The base container includes the minimum dependencies to run any pipeline, while pipeline-specific containers add the necessary dependencies for their respective pipelines. This structure allows for faster build times, less dependency bloat, and easier maintenance.

### Steps to Build a Pipeline-Specific Container

To build a pipeline-specific container, you need to build the base container first. The base container is tagged as `base`, and the pipeline-specific container is built from the Dockerfile in the pipeline-specific folder. For example, to build the `segment-anything-2` pipeline-specific container, follow these steps:

1. **Navigate to the `ai-worker/runner` Directory**:

```bash
cd ai-worker/runner
```

2. **Build the Base Container**:

```bash
docker build -t livepeer/ai-runner:base .
```

This command builds the base container and tags it as `livepeer/ai-runner:base`.

3. **Build the `segment-anything-2` Pipeline-Specific Container**:

```bash
docker build -f docker/Dockerfile.segment_anything_2 -t livepeer/ai-runner:segment-anything-2 .
```

This command builds the `segment-anything-2` pipeline-specific container using the Dockerfile located at [docker/Dockerfile.segment_anything_2](docker/Dockerfile.segment_anything_2) and tags it as `livepeer/ai-runner:segment-anything-2`.