diff --git a/examples/README.md b/examples/README.md
index d527c7de244..6161027be74 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -15,6 +15,7 @@ The following examples are included for training:
 The following examples are included for inference:
 
 - [Image classification example](docs/image_classification.md)
+- [Segment anything 2 example](docs/segment_anything_2.md)
 - [Single-shot object detection example](docs/object_detection.md)
 - [Face detection example](docs/face_detection.md)
 - [Face recognition example](docs/face_recognition.md)
diff --git a/examples/docs/segment_anything_2.md b/examples/docs/segment_anything_2.md
new file mode 100644
index 00000000000..4ffc11b9eab
--- /dev/null
+++ b/examples/docs/segment_anything_2.md
@@ -0,0 +1,173 @@
+# Segment anything 2 example
+
+[Mask generation](https://huggingface.co/tasks/mask-generation) is the task of generating masks that
+identify a specific object or region of interest in a given image.
+
+In this example, you learn how to implement inference code with a [ModelZoo model](../../docs/model-zoo.md) to
+generate mask of a selected object in an image.
+
+The source code can be found
+at [SegmentAnything2.java](https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/cv/SegmentAnything2.java).
+
+## Setup guide
+
+To configure your development environment, follow [setup](../../docs/development/setup.md).
+
+## Run segment anything 2 example
+
+### Input image file
+
+You can find the image used in this example:
+
+![truck](https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg)
+
+### Build the project and run
+
+Use the following command to run the project:
+
+```sh
+cd examples
+./gradlew run -Dmain=ai.djl.examples.inference.cv.SegmentAnything2
+```
+
+Your output should look like the following:
+
+```text
+[INFO ] - Number of inter-op threads is 12
+[INFO ] - Number of intra-op threads is 6
+[INFO ] - Segmentation result image has been saved in: build/output/sam2.png
+[INFO ] - [
+        {"class": "", "probability": 0.92789, "bounds": {"x"=0.000, "y"=0.000, "width"=1800.000, "height"=1200.000}}
+]
+```
+
+An output image with bounding box will be saved as build/output/sam2.png:
+
+![mask](https://resources.djl.ai/images/sam2_truck_1.png)
+
+## Reference - how to import pytorch model:
+
+The original model can be found:
+
+- [sam2-hiera-large](https://huggingface.co/facebook/sam2-hiera-large)
+- [sam2-hiera-tiny](https://huggingface.co/facebook/sam2-hiera-tiny)
+
+The model zoo model was traced with `sam2==0.4.1` and `transformers==4.43.4`
+
+### install dependencies
+
+```bash
+pip install sam2 transformers
+```
+
+### trace the model
+
+```python
+import sys
+from typing import Tuple
+
+import torch
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from torch import nn
+
+
+class Sam2Wrapper(nn.Module):
+
+    def __init__(
+        self,
+        sam_model: SAM2Base,
+    ) -> None:
+        super().__init__()
+        self.model = sam_model
+
+        # Spatial dim for backbone feature maps
+        self._bb_feat_sizes = [
+            (256, 256),
+            (128, 128),
+            (64, 64),
+        ]
+
+    def extract_features(
+        self,
+        input_image: torch.Tensor,
+    ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        backbone_out = self.model.forward_image(input_image)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(
+            backbone_out)
+        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+        feats = [
+            feat.permute(1, 2,
+                         0).view(1, -1, *feat_size) for feat, feat_size in zip(
+                             vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+
+        return feats[-1], feats[0], feats[1]
+
+    def forward(
+        self,
+        input_image: torch.Tensor,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        image_embed, feature_1, feature_2 = self.extract_features(input_image)
+        return self.predict(point_coords, point_labels, image_embed, feature_1,
+                            feature_2)
+
+    def predict(
+        self,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+        image_embed: torch.Tensor,
+        feats_1: torch.Tensor,
+        feats_2: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        concat_points = (point_coords, point_labels)
+
+        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+            points=concat_points,
+            boxes=None,
+            masks=None,
+        )
+
+        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+            image_embeddings=image_embed[0].unsqueeze(0),
+            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=True,
+            repeat_image=False,
+            high_res_features=[feats_1, feats_2],
+        )
+        return low_res_masks, iou_predictions
+
+
+def trace_model(model_id: str):
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
+    model = Sam2Wrapper(predictor.model)
+
+    input_image = torch.ones(1, 3, 1024, 1024).to(device)
+    input_point = torch.ones(1, 1, 2).to(device)
+    input_labels = torch.ones(1, 1, dtype=torch.int32, device=device)
+
+    converted = torch.jit.trace_module(
+        model, {
+            "extract_features": input_image,
+            "forward": (input_image, input_point, input_labels)
+        })
+    torch.jit.save(converted, f"{model_id[9:]}.pt")
+
+
+if __name__ == '__main__':
+    hf_model_id = sys.argv[1] if len(
+        sys.argv) > 1 else "facebook/sam2-hiera-tiny"
+    trace_model(hf_model_id)
+```
diff --git a/examples/docs/trace_sam2_img.py b/examples/docs/trace_sam2_img.py
new file mode 100644
index 00000000000..bb80795202b
--- /dev/null
+++ b/examples/docs/trace_sam2_img.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+#
+# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
+# except in compliance with the License. A copy of the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
+# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
+# the specific language governing permissions and limitations under the License.
+import sys
+from typing import Tuple
+
+import torch
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from torch import nn
+
+
+class Sam2Wrapper(nn.Module):
+
+    def __init__(
+        self,
+        sam_model: SAM2Base,
+    ) -> None:
+        super().__init__()
+        self.model = sam_model
+
+        # Spatial dim for backbone feature maps
+        self._bb_feat_sizes = [
+            (256, 256),
+            (128, 128),
+            (64, 64),
+        ]
+
+    def extract_features(
+        self,
+        input_image: torch.Tensor,
+    ) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        backbone_out = self.model.forward_image(input_image)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(
+            backbone_out)
+        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+        feats = [
+            feat.permute(1, 2,
+                         0).view(1, -1, *feat_size) for feat, feat_size in zip(
+                             vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+
+        return feats[-1], feats[0], feats[1]
+
+    def forward(
+        self,
+        input_image: torch.Tensor,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        image_embed, feature_1, feature_2 = self.extract_features(input_image)
+        return self.predict(point_coords, point_labels, image_embed, feature_1,
+                            feature_2)
+
+    def predict(
+        self,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+        image_embed: torch.Tensor,
+        feats_1: torch.Tensor,
+        feats_2: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        concat_points = (point_coords, point_labels)
+
+        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+            points=concat_points,
+            boxes=None,
+            masks=None,
+        )
+
+        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+            image_embeddings=image_embed[0].unsqueeze(0),
+            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=True,
+            repeat_image=False,
+            high_res_features=[feats_1, feats_2],
+        )
+        return low_res_masks, iou_predictions
+
+
+def trace_model(model_id: str):
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    predictor = SAM2ImagePredictor.from_pretrained(model_id, device=device)
+    model = Sam2Wrapper(predictor.model)
+
+    input_image = torch.ones(1, 3, 1024, 1024).to(device)
+    input_point = torch.ones(1, 1, 2).to(device)
+    input_labels = torch.ones(1, 1, dtype=torch.int32, device=device)
+
+    converted = torch.jit.trace_module(
+        model, {
+            "extract_features": input_image,
+            "forward": (input_image, input_point, input_labels)
+        })
+    torch.jit.save(converted, f"{model_id[9:]}.pt")
+
+
+if __name__ == '__main__':
+    hf_model_id = sys.argv[1] if len(
+        sys.argv) > 1 else "facebook/sam2-hiera-tiny"
+    trace_model(hf_model_id)