Skip to content

Commit

Permalink
Merge pull request #11 from computational-cell-analytics/custom-ckpt
Browse files Browse the repository at this point in the history
Add support for custom checkpoint
  • Loading branch information
ksugar authored Sep 13, 2023
2 parents 6be6c93 + 803e7a4 commit 24b85ce
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/samapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,12 @@ def _get_device() -> str:
return device


def get_sam_model(model_type: ModelType):
def get_sam_model(model_type: ModelType, checkpoint_url: Optional[str] = None):
sam = sam_model_registry[model_type]()
sam.load_state_dict(SAM_CHECKPOINTS[model_type])
if checkpoint_url is None:
sam.load_state_dict(SAM_CHECKPOINTS[model_type])
else:
sam.load_state_dict(load_state_dict_from_url(checkpoint_url))
return sam


Expand All @@ -98,6 +101,7 @@ class SAMBody(BaseModel):
b64img: str
b64mask: Optional[str] = None
multimask_output: bool = False
checkpoint_url: Optional[str] = None


@app.post("/sam/")
Expand All @@ -106,7 +110,7 @@ async def predict_sam(body: SAMBody):
global predictor
global last_image
if body.type != sam_type:
predictor = SamPredictor(get_sam_model(body.type).to(device=device))
predictor = SamPredictor(get_sam_model(body.type, body.checkpoint_url).to(device=device))
sam_type = body.type
last_image = None
if last_image != body.b64img:
Expand Down Expand Up @@ -168,7 +172,7 @@ async def automatic_mask_generator(body: SAMAutoMaskBody):
global predictor
global last_image
if body.type != sam_type:
predictor = SamPredictor(get_sam_model(body.type).to(device=device))
predictor = SamPredictor(get_sam_model(body.type, body.checkpoint_url).to(device=device))
sam_type = body.type
last_image = None
if last_image != body.b64img:
Expand Down

0 comments on commit 24b85ce

Please sign in to comment.