From 2c7d2500e633a8a913d3a544d8ab4d6ce5379e94 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 4 Oct 2024 14:56:32 -0700 Subject: [PATCH] [api] Improve Sam2Translator for PyTorch traced model (#3495) Fixes: #3484 --- .../cv/translator/Sam2Translator.java | 31 +++++++++++- .../java/ai/djl/pytorch/zoo/PtModelZoo.java | 6 +-- examples/docs/segment_anything_2.md | 48 ++++++++----------- examples/docs/trace_sam2_img.py | 48 ++++++++----------- 4 files changed, 74 insertions(+), 59 deletions(-) diff --git a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java index f1b75dd95d8..751019c760f 100644 --- a/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java +++ b/api/src/main/java/ai/djl/modality/cv/translator/Sam2Translator.java @@ -58,6 +58,7 @@ public class Sam2Translator implements NoBatchifyTranslator predictor; private String encoderPath; + private String encodeMethod; /** Constructs a {@code Sam2Translator} instance. */ public Sam2Translator(Builder builder) { @@ -66,12 +67,19 @@ public Sam2Translator(Builder builder) { pipeline.add(new ToTensor()); pipeline.add(new Normalize(MEAN, STD)); this.encoderPath = builder.encoderPath; + this.encodeMethod = builder.encodeMethod; } /** {@inheritDoc} */ @Override public void prepare(TranslatorContext ctx) throws IOException, ModelException { if (encoderPath == null) { + // PyTorch model + if (encodeMethod != null) { + Model model = ctx.getModel(); + predictor = model.newPredictor(new NoopTranslator(null)); + model.getNDManager().attachInternal(UUID.randomUUID().toString(), predictor); + } return; } Model model = ctx.getModel(); @@ -111,7 +119,15 @@ public NDList processInput(TranslatorContext ctx, Sam2Input input) throws Except return new NDList(array, locations, labels); } - NDList embeddings = predictor.predict(new NDList(array)); + NDList embeddings; + if (encodeMethod == null) { + embeddings = predictor.predict(new NDList(array)); + } else { + NDArray placeholder = manager.create(""); + placeholder.setName("module_method:" + encodeMethod); + embeddings = predictor.predict(new NDList(placeholder, array)); + } + NDArray mask = manager.zeros(new Shape(1, 1, 256, 256)); NDArray hasMask = manager.zeros(new Shape(1)); return new NDList( @@ -173,9 +189,11 @@ public static Builder builder(Map arguments) { public static class Builder { String encoderPath; + String encodeMethod; Builder(Map arguments) { encoderPath = ArgumentsUtil.stringValue(arguments, "encoder"); + encodeMethod = ArgumentsUtil.stringValue(arguments, "encode_method"); } /** @@ -189,6 +207,17 @@ public Builder optEncoderPath(String encoderPath) { return this; } + /** + * Sets the module name for encode method. + * + * @param encodeMethod the module name for encode method + * @return the builder + */ + public Builder optEncodeMethod(String encodeMethod) { + this.encodeMethod = encodeMethod; + return this; + } + /** * Builds the translator. * diff --git a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java index b92fea429b3..5ea9920b881 100644 --- a/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java +++ b/engines/pytorch/pytorch-model-zoo/src/main/java/ai/djl/pytorch/zoo/PtModelZoo.java @@ -43,10 +43,8 @@ public class PtModelZoo extends ModelZoo { addModel( REPOSITORY.model(CV.IMAGE_CLASSIFICATION, GROUP_ID, "resnet18_embedding", "0.0.1")); addModel(REPOSITORY.model(CV.INSTANCE_SEGMENTATION, GROUP_ID, "yolov8n-seg", "0.0.1")); - addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.1")); - addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny-gpu", "0.0.1")); - addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.1")); - addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large-gpu", "0.0.1")); + addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-tiny", "0.0.2")); + addModel(REPOSITORY.model(CV.MASK_GENERATION, GROUP_ID, "sam2-hiera-large", "0.0.2")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "ssd", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov5s", "0.0.1")); addModel(REPOSITORY.model(CV.OBJECT_DETECTION, GROUP_ID, "yolov8n", "0.0.1")); diff --git a/examples/docs/segment_anything_2.md b/examples/docs/segment_anything_2.md index 768a666b87f..6c66bbe9138 100644 --- a/examples/docs/segment_anything_2.md +++ b/examples/docs/segment_anything_2.md @@ -73,15 +73,21 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor from torch import nn -class SAM2ImageEncoder(nn.Module): +class Sam2Wrapper(nn.Module): - def __init__(self, sam_model: SAM2Base) -> None: + def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None: super().__init__() self.model = sam_model self.image_encoder = sam_model.image_encoder self.no_mem_embed = sam_model.no_mem_embed + self.mask_decoder = sam_model.sam_mask_decoder + self.prompt_encoder = sam_model.sam_prompt_encoder + self.img_size = sam_model.image_size + self.multimask_output = multimask_output + self.sparse_embedding = None - def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]: + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tuple[Any, Any, Any]: backbone_out = self.image_encoder(x) backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0( backbone_out["backbone_fpn"][0]) @@ -106,18 +112,6 @@ class SAM2ImageEncoder(nn.Module): return feats[0], feats[1], feats[2] - -class SAM2ImageDecoder(nn.Module): - - def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None: - super().__init__() - self.mask_decoder = sam_model.sam_mask_decoder - self.prompt_encoder = sam_model.sam_prompt_encoder - self.model = sam_model - self.img_size = sam_model.image_size - self.multimask_output = multimask_output - self.sparse_embedding = None - @torch.no_grad() def forward( self, @@ -205,17 +199,13 @@ def trace_model(model_id: str): device = torch.device("cpu") model_name = f"{model_id[9:]}" - os.makedirs(model_name) + os.makedirs(model_name, exist_ok=True) predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device) - encoder = SAM2ImageEncoder(predictor.model) - decoder = SAM2ImageDecoder(predictor.model, True) + model = Sam2Wrapper(predictor.model, True) input_image = torch.ones(1, 3, 1024, 1024).to(device) - high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image) - - converted = torch.jit.trace(encoder, input_image) - torch.jit.save(converted, f"model_name/encoder.pt") + high_res_feats_0, high_res_feats_1, image_embed = model.encode(input_image) # trace decoder model embed_size = ( @@ -232,10 +222,14 @@ def trace_model(model_id: str): mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float) has_mask_input = torch.tensor([1], dtype=torch.float) - converted = torch.jit.trace( - decoder, (image_embed, high_res_feats_0, high_res_feats_1, - point_coords, point_labels, mask_input, has_mask_input)) - torch.jit.save(converted, f"model_name/model_name.pt") + converted = torch.jit.trace_module( + model, { + "encode": + input_image, + "forward": (image_embed, high_res_feats_0, high_res_feats_1, + point_coords, point_labels, mask_input, has_mask_input) + }) + torch.jit.save(converted, f"{model_name}/{model_name}.pt") # save serving.properties serving_file = os.path.join(model_name, "serving.properties") @@ -244,7 +238,7 @@ def trace_model(model_id: str): f"engine=PyTorch\n" f"option.modelName={model_name}\n" f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n" - f"encoder=encoder.pt") + f"encode_method=encode\n") if __name__ == '__main__': diff --git a/examples/docs/trace_sam2_img.py b/examples/docs/trace_sam2_img.py index ce5a94ed0ce..650284cc101 100644 --- a/examples/docs/trace_sam2_img.py +++ b/examples/docs/trace_sam2_img.py @@ -20,15 +20,21 @@ from torch import nn -class SAM2ImageEncoder(nn.Module): +class Sam2Wrapper(nn.Module): - def __init__(self, sam_model: SAM2Base) -> None: + def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None: super().__init__() self.model = sam_model self.image_encoder = sam_model.image_encoder self.no_mem_embed = sam_model.no_mem_embed + self.mask_decoder = sam_model.sam_mask_decoder + self.prompt_encoder = sam_model.sam_prompt_encoder + self.img_size = sam_model.image_size + self.multimask_output = multimask_output + self.sparse_embedding = None - def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]: + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tuple[Any, Any, Any]: backbone_out = self.image_encoder(x) backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0( backbone_out["backbone_fpn"][0]) @@ -53,18 +59,6 @@ def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]: return feats[0], feats[1], feats[2] - -class SAM2ImageDecoder(nn.Module): - - def __init__(self, sam_model: SAM2Base, multimask_output: bool) -> None: - super().__init__() - self.mask_decoder = sam_model.sam_mask_decoder - self.prompt_encoder = sam_model.sam_prompt_encoder - self.model = sam_model - self.img_size = sam_model.image_size - self.multimask_output = multimask_output - self.sparse_embedding = None - @torch.no_grad() def forward( self, @@ -152,17 +146,13 @@ def trace_model(model_id: str): device = torch.device("cpu") model_name = f"{model_id[9:]}" - os.makedirs(model_name) + os.makedirs(model_name, exist_ok=True) predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device) - encoder = SAM2ImageEncoder(predictor.model) - decoder = SAM2ImageDecoder(predictor.model, True) + model = Sam2Wrapper(predictor.model, True) input_image = torch.ones(1, 3, 1024, 1024).to(device) - high_res_feats_0, high_res_feats_1, image_embed = encoder(input_image) - - converted = torch.jit.trace(encoder, input_image) - torch.jit.save(converted, f"model_name/encoder.pt") + high_res_feats_0, high_res_feats_1, image_embed = model.encode(input_image) # trace decoder model embed_size = ( @@ -179,10 +169,14 @@ def trace_model(model_id: str): mask_input = torch.randn(1, 1, *mask_input_size, dtype=torch.float) has_mask_input = torch.tensor([1], dtype=torch.float) - converted = torch.jit.trace( - decoder, (image_embed, high_res_feats_0, high_res_feats_1, - point_coords, point_labels, mask_input, has_mask_input)) - torch.jit.save(converted, f"model_name/model_name.pt") + converted = torch.jit.trace_module( + model, { + "encode": + input_image, + "forward": (image_embed, high_res_feats_0, high_res_feats_1, + point_coords, point_labels, mask_input, has_mask_input) + }) + torch.jit.save(converted, f"{model_name}/{model_name}.pt") # save serving.properties serving_file = os.path.join(model_name, "serving.properties") @@ -191,7 +185,7 @@ def trace_model(model_id: str): f"engine=PyTorch\n" f"option.modelName={model_name}\n" f"translatorFactory=ai.djl.modality.cv.translator.Sam2TranslatorFactory\n" - f"encoder=encoder.pt") + f"encode_method=encode\n") if __name__ == '__main__':